# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import builtins
import collections
import inspect
import itertools
import math
import operator
import warnings
from collections.abc import Iterable, Sequence
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
from typing import Any, Callable, cast, Optional, overload, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch.utils._pytree as pytree
from torch import sym_float, sym_int
from torch._prims_common import (
    BoolLike,
    DeviceLikeType,
    Dim,
    DimsSequenceType,
    DimsType,
    dtype_to_type,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    FloatLike,
    FloatWithoutSymFloat,
    IntLike,
    is_weakly_lesser_type,
    Number,
    NumberType,
    RealNumberType,
    REDUCTION_OUTPUT_TYPE_KIND,
    ShapeType,
    StrideType,
    TensorLike,
    TensorLikeType,
    TensorOrNumberLikeType,
    TensorSequenceType,
)
from torch._prims_common.wrappers import (
    _maybe_convert_to_dtype,
    _maybe_resize_out,
    _safe_copy_out,
    elementwise_type_promotion_wrapper,
    elementwise_unary_scalar_wrapper,
    out_wrapper,
)


# Experimental module containing prototype Python references for existing
#   PyTorch operations.

__all__ = [
    #
    # Elementwise Unary References
    #
    "abs",
    "acos",
    "acosh",
    "asinh",
    "asin",
    "atan",
    "atanh",
    "bitwise_not",
    # "cbrt",  # No corresponding torch operation
    "ceil",
    "conj_physical",
    "cos",
    "cosh",
    "count_nonzero",
    "deg2rad",
    "digamma",
    "erf",
    "erfinv",
    "erfc",
    "exp",
    "expm1",
    "exponential",
    "exp2",
    "fill",
    "fill_",
    "floor",
    "frac",
    "geometric",
    "index_add",
    "index_copy",
    "index_copy_",
    "index_select",
    "index_fill",
    "index_fill_",
    "isfinite",
    "isinf",
    "isposinf",
    "isneginf",
    "isnan",
    "isreal",
    "i0",
    "lerp",
    "lgamma",
    "log",
    "log1p",
    "log2",
    "log10",
    "log_normal",
    "log_softmax",
    "mvlgamma",
    "norm",
    "normal",
    "nan_to_num",
    "neg",
    "positive",
    "rad2deg",
    "reciprocal",
    "round",  # TODO: model kwargs
    "sigmoid",
    "sgn",
    "sign",
    "signbit",
    "sin",
    "sinc",
    "sinh",
    "softmax",
    "sqrt",
    "square",
    "tan",
    "tanh",
    "trace",
    "trunc",
    #
    # Elementwise Binary References
    #
    "add",
    "atan2",
    "bitwise_and",
    "bitwise_left_shift",
    "bitwise_or",
    "bitwise_right_shift",
    "bitwise_xor",
    "clamp_min",
    "clamp_max",
    "copysign",
    "div",
    "eq",
    "float_power",
    "floor_divide",
    "fmax",
    "fmin",
    "fmod",
    "gcd",
    "ge",
    "gt",
    "heaviside",
    "hypot",
    "igamma",
    "igammac",
    "imag",
    "isclose",
    "lcm",
    # 'ldexp',
    "le",
    "logaddexp",
    "logaddexp2",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "logsumexp",
    "lt",
    # 'max', # implement with reductions
    "maximum",
    # 'min', # implement with reductions
    "minimum",
    "mul",
    "ne",
    "nextafter",
    # 'polar',  # abs, cos, sin
    "pow",
    "real",
    "rpow",
    "remainder",
    "rsub",
    "rtruediv",
    "rfloordiv",
    "sub",
    "true_divide",
    "trunc_divide",
    "xlogy",
    #
    # Elementwise Ternary References
    #
    "addcdiv",
    "addcmul",
    "clamp",
    #
    # Conditional references
    #
    "masked_fill",
    "masked_fill_",
    "where",
    #
    # Data conversion and movement references
    #
    "clone",
    "copy_to",  # TODO: add OpInfo (or implement .to)
    "item",
    "to",
    #
    # Reduction ops
    #
    "all",
    "amax",
    "amin",
    "any",
    "cumsum",
    "cumprod",
    "mean",
    "dot",
    "vdot",
    "std",
    "std_mean",
    "sum",
    "sum_to_size",
    "prod",
    "var",
    "var_mean",
    #
    # Linear algebra ops
    #
    "addr",
    #
    # View & Shape Ops
    #
    "alias",
    "alias_copy",
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "as_strided",
    "as_strided_copy",
    "as_strided_scatter",
    "block_diag",
    "broadcast_shapes",
    "broadcast_tensors",
    "broadcast_to",
    "cat",
    "chunk",
    "column_stack",
    "conj",
    "constant_pad_nd",
    "contiguous",
    "diag_embed",
    "diag",
    "diagonal",
    "diagonal_copy",
    "diagonal_scatter",
    "dsplit",
    "dstack",
    "expand",
    "expand_as",
    "expand_copy",
    "flatten",
    "flip",
    "fliplr",
    "flipud",
    "hsplit",
    "hstack",
    "meshgrid",
    "movedim",
    "narrow",
    "narrow_copy",
    "native_group_norm",
    "native_layer_norm",
    "permute",
    "permute_copy",
    "ravel",
    "repeat",
    "reshape",
    "reshape_as",
    "roll",
    "rot90",
    "rsqrt",
    "split_with_sizes",
    "stack",
    "swap_axes",  # alias for transpose
    "squeeze",
    "squeeze_copy",
    "t",
    "t_copy",
    "T",
    "take_along_dim",
    "tensor_split",
    "transpose",
    "transpose_copy",
    "unbind_copy",
    "unfold",
    "unfold_copy",
    "unsqueeze",
    "unsqueeze_copy",
    "view",
    "view_as",
    "view_copy",
    "vsplit",
    "vstack",
    "view_as_complex",
    "unflatten",
    "unbind",
    "triu",
    "tril",
    "triu_indices",
    "tril_indices",
    #
    # Tensor Creation
    #
    "arange",
    "cauchy",
    "empty",
    "empty_like",
    "empty_permuted",
    "empty_strided",
    "eye",
    "full",
    "full_like",
    "linspace",
    "logspace",
    "new_empty",
    "new_empty_strided",
    "new_full",
    "new_ones",
    "new_zeros",
    "ones",
    "ones_like",
    "randn",
    "scalar_tensor",
    "zero",
    "zeros",
    "zeros_like",
    #
    # Test-related functions
    #
    "allclose",
    "equal",
    #
    # Statistical operations
    #
    "bucketize",
    #
    # Misc
    #
    "is_complex",
    "renorm",
    "stft",
    "istft",
]

Tensor = torch.Tensor
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
aten = torch._ops.ops.aten

# Note that the docstrings for the public methods from this file are in
# torch/_torch_docs.py


def is_noncontiguous_supported(device):
    return device is None or device.type != "hpu"


def handle_noncontiguous_outputs(input_tlist, output):
    device = None
    from torch._subclasses.fake_tensor import FakeTensor

    for t in input_tlist:
        if isinstance(t, FakeTensor):
            device = t.fake_device
            break

    if not is_noncontiguous_supported(device):
        output = output.contiguous()

    return output


def _broadcast_shapes(*_shapes):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    shapes = tuple(
        (x,) if isinstance(x, IntLike) else x
        for x in filter(lambda x: x is not None, _shapes)
    )

    # Short-circuits on no input
    if len(shapes) == 0:
        return None

    # Type checking
    # TODO: make common validations available as utils
    for shape in shapes:
        assert isinstance(shape, Sequence)

    # Computes common shape
    common_shape: list[Union[int, torch.SymInt]] = [
        1,
    ] * reduce(max, (len(shape) for shape in shapes))
    for arg_idx, shape in enumerate(shapes):
        for idx in range(-1, -1 - len(shape), -1):
            if guard_size_oblivious(common_shape[idx] == 1):
                if shape[idx] < 0:
                    raise ValueError(
                        "Attempting to broadcast a dimension with negative length!"
                    )
                common_shape[idx] = shape[idx]
            elif guard_size_oblivious(shape[idx] != 1):
                torch._check(
                    common_shape[idx] == shape[idx],
                    lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! "
                    f"Mismatching argument at index {arg_idx} had {shape}; but expected shape "
                    f"should be broadcastable to {common_shape}",
                )

    return common_shape
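
# Illustrative example of the broadcasting rule implemented above (standard
# NumPy-style broadcasting; missing leading dimensions are treated as 1):
# >>> _broadcast_shapes((3, 1), (1, 4))
# [3, 4]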


def _maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
    # Computes common shape
    common_shape = _broadcast_shapes(
        *(t.shape if isinstance(t, TensorLike) else None for t in args)
    )

    def __maybe_broadcast(x, shape):
        if x is None:
            return None
        elif isinstance(x, Number):
            return x
        elif isinstance(x, TensorLike):
            if preserve_cpu_scalar_tensors and utils.is_cpu_scalar_tensor(x):
                return x

            if not utils.same_shape(x.shape, common_shape):
                return x.expand(common_shape)

            return x
        else:
            raise RuntimeError(
                "Unexpected type when broadcasting: " + str(type(x)) + "!"
            )

    return tuple(__maybe_broadcast(x, common_shape) for x in args)
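
# Illustrative example of _maybe_broadcast: tensors are expanded to the common
# shape, while Python numbers (and, by default, CPU scalar tensors) pass through:
# >>> x, y = _maybe_broadcast(torch.empty(3, 1), torch.empty(1, 4))
# >>> x.shape, y.shape
# (torch.Size([3, 4]), torch.Size([3, 4]))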


# Utilities should come BEFORE this import
from torch._decomp import register_decomposition


#
# Elementwise unary references
#

infer_aten_op = object()


# TODO: add type promotion support
def _make_elementwise_unary_reference(
    type_promotion_kind,
    *,
    aten_op=infer_aten_op,
    extra_meta=None,
    exact_dtype=False,
) -> Callable:
    def inner(prim: Callable):
        nonlocal aten_op

        @wraps(prim)
        @out_wrapper(exact_dtype=exact_dtype)
        @elementwise_unary_scalar_wrapper
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a",),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(a: TensorLikeType) -> TensorLikeType:
            if extra_meta is not None:
                extra_meta(a)

            output = prim(a)
            return handle_noncontiguous_outputs([a], output)

        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, prim.__name__)
        if aten_op is not None:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner


def _make_alias(fn, name):
    """
    This function defines an alias of another function and sets its __name__ attribute.
    It also sets its __module__ attribute to the module of the caller.
    Note that when naively doing `alias = fn`, the alias keeps the original function's
    metadata, i.e. `alias.__name__ == fn.__name__` and `alias.__module__ == fn.__module__`.
    """

    def _fn(*args, **kwargs):
        return fn(*args, **kwargs)

    _fn.__name__ = name
    _fn.__module__ = inspect.currentframe().f_back.f_globals["__name__"]  # type: ignore[union-attr]
    return _fn
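
# Illustrative (hypothetical) usage of _make_alias:
# >>> cosine = _make_alias(torch.cos, "cosine")
# >>> cosine.__name__
# 'cosine'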


def _make_inplace(fn):
    """
    Given a function with an out variant (i.e. using `out_wrapper()`), this returns
    its in-place variant.
    See https://github.com/pytorch/pytorch/wiki/Developer-FAQ#how-do-in-place-operations-work-in-pytorch
    """

    # nb. We use the name of the first argument used in the unary references
    @wraps(fn)
    def _fn(a, *args, **kwargs):
        return fn(a, *args, out=a, **kwargs)

    inplace_name = f"{fn.__name__}_"
    _fn.__name__ = inplace_name
    _fn = register_decomposition(getattr(aten, inplace_name))(_fn)  # type: ignore[assignment]

    # We access the __all__ attribute of the module where fn is defined
    # There may be a cleaner way of doing this...
    from inspect import getmodule

    _all = getmodule(fn).__all__  # type: ignore[union-attr]
    if inplace_name not in _all:
        _all.append(inplace_name)
    return _fn
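
# Illustrative (hypothetical) usage of _make_inplace, assuming the corresponding
# ATen in-place overload (here aten.abs_) exists:
# >>> abs_ = _make_inplace(abs)  # abs_(a) now computes abs(a, out=a)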


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    exact_dtype=True,
)
def abs(a):
    return prims.abs(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acos(a):
    return prims.acos(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def acosh(a):
    return prims.acosh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asin(a):
    return prims.asin(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def asinh(a):
    return prims.asinh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atan(a):
    return prims.atan(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def atanh(a):
    return prims.atanh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
def bitwise_not(a):
    return prims.bitwise_not(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def ceil(a):
    return prims.ceil(a)


@register_decomposition(aten.is_complex)
def is_complex(input: TensorLikeType):
    return utils.is_complex_dtype(input.dtype)


@register_decomposition(aten.conj_physical)
@out_wrapper()
def conj_physical(input: TensorLikeType):
    if not utils.is_complex_dtype(input.dtype):
        return input
    return prims.conj_physical(input)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cos(a):
    return prims.cos(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def cosh(a):
    return prims.cosh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def digamma(a):
    return prims.digamma(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erf(a):
    return prims.erf(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfinv(a):
    return prims.erf_inv(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def erfc(a):
    return prims.erfc(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp(a):
    return prims.exp(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def expm1(a):
    return prims.expm1(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def exp2(a):
    return prims.exp2(a)


# Fill has its own implementation because it has a value parameter
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a,"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def fill(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    assert isinstance(value, Number)

    python_type = utils.dtype_to_type(a.dtype)
    if not utils.is_weakly_lesser_type(type(value), python_type):
        msg = f"value argument of type {type(value)} cannot be safely cast to type {python_type}!"
        raise ValueError(msg)

    return prims.fill(a, value)


def fill_(a: TensorLikeType, value: NumberType) -> TensorLikeType:
    r = prims.fill(a, value)
    prims.copy_to(a, r)
    return a


@register_decomposition(aten.zero)
@out_wrapper()
def zero(input: TensorLikeType) -> TensorLikeType:
    return torch.zeros_like(input)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def floor(a):
    return prims.floor(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def frac(x: TensorLikeType) -> TensorLikeType:
    trunc_x = torch.mul(torch.floor(torch.abs(x)), torch.sign(x))
    return torch.sub(x, trunc_x)
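
# Illustrative example: the truncation above keeps the sign of the input, so
# frac(torch.tensor(-2.5)) == torch.tensor(-0.5), matching torch.frac.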


# imag does not use _make_elementwise_unary_reference because it does not support out
def imag(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    torch._check(
        utils.is_complex_dtype(a.dtype), lambda: "imag only supports complex tensors."
    )
    return prims.imag(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isfinite(a: TensorLikeType) -> TensorLikeType:
    if utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype):
        return prims.isfinite(a)

    return ones_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isinf(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        return torch.logical_or(isinf(torch.real(a)), isinf(torch.imag(a)))
    if utils.is_float_dtype(a.dtype):
        return torch.abs(a) == float("inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    exact_dtype=True,
)
def isposinf(a: TensorLikeType) -> TensorLikeType:
    torch._check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isposinf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return a == float("inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    exact_dtype=True,
)
def isneginf(a: TensorLikeType) -> TensorLikeType:
    torch._check(
        not utils.is_complex_dtype(a.dtype),
        lambda: f"Complex dtype is not supported for isneginf, got dtype {a.dtype}",
    )
    if utils.is_float_dtype(a.dtype):
        return a == float("-inf")
    return torch.zeros_like(a, dtype=torch.bool)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def isnan(a: TensorLikeType) -> TensorLikeType:
    return prims.ne(a, a)


# alias
mvlgamma = _make_alias(torch.special.multigammaln, "mvlgamma")  # type: ignore[has-type]


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    aten_op=None,  # CompositeImplicitAutograd
)
def isreal(a: TensorLikeType) -> TensorLikeType:
    if utils.is_complex_dtype(a.dtype):
        return torch.imag(a) == 0
    return torch.ones_like(a, dtype=torch.bool)


# TODO: if this belongs in torch.special, maybe it should be defined there and imported here?
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT, aten_op=aten.i0
)
def i0(a):
    return prims.bessel_i0(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def lgamma(a):
    return prims.lgamma(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log(a):
    return prims.log(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log1p(a):
    return prims.log1p(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log2(a):
    return prims.log2(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def log10(a):
    return prims.log10(a)


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def log_softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    return _maybe_convert_to_dtype(a_ - logsumexp(a_, dim, keepdim=True), result_dtype)  # type: ignore[return-value]


@register_decomposition(aten.logsumexp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logsumexp(
    self: TensorLikeType, dim: DimsType, keepdim: bool = False
) -> TensorLikeType:
    if not isinstance(dim, Iterable):
        dim = (dim,)
    if self.numel() == 0:
        return torch.sum(torch.exp(self), dim, keepdim).log()
    maxes = torch.amax(torch.real(self), dim, keepdim=True)
    maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
    maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
    result = torch.sum(torch.exp(self - maxes), dim, keepdim)
    return result.log().add(maxes_squeezed)
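
# The implementation above uses the shifted identity
#   logsumexp(x) = max(x) + log(sum(exp(x - max(x))))
# to avoid overflow; maxes of infinite magnitude are zeroed first so that
# `self - maxes` cannot produce spurious nans from inf - inf.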


@register_decomposition(aten.nan_to_num)
@out_wrapper()
def nan_to_num(
    a: TensorLikeType,
    nan: Optional[NumberType] = 0.0,
    posinf: Optional[NumberType] = None,
    neginf: Optional[NumberType] = None,
) -> TensorLikeType:
    assert isinstance(a, TensorLike)

    if utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
        return a.clone()

    if nan is None:
        nan = 0.0

    if posinf is None:
        posinf = torch.finfo(a.dtype).max

    if neginf is None:
        neginf = torch.finfo(a.dtype).min

    result = torch.where(torch.isnan(a), nan, a)  # type: ignore[call-overload]
    result = torch.where(torch.isneginf(a), neginf, result)  # type: ignore[call-overload]
    result = torch.where(torch.isposinf(a), posinf, result)  # type: ignore[call-overload]
    return result
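
# Illustrative example (posinf/neginf default to the dtype's max/min):
# >>> torch.nan_to_num(torch.tensor([float("nan"), float("inf")]), posinf=1.0)
# tensor([0., 1.])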


def _neg_meta(a: TensorLikeType):
    torch._check(
        a.dtype is not torch.bool,
        lambda: (
            "Negation, the `-` operator, on a bool tensor is not supported. "
            "If you are trying to invert a mask, use the `~` or `logical_not()` "
            "operator instead."
        ),
    )


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, extra_meta=_neg_meta
)
def neg(a):
    return prims.neg(a)


# positive does not use _make_elementwise_unary_reference because it does not support out
# CompositeImplicitAutograd - don't register decomp
def positive(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if a.dtype is torch.bool:
        msg = "positive does not support bool tensors."
        raise RuntimeError(msg)
    return a


# real does not use _make_elementwise_unary_reference because it does not support out
def real(a: TensorLikeType) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if utils.is_complex_dtype(a.dtype):
        return prims.real(a)
    return a


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def reciprocal(a):
    return prims.reciprocal(a)


@register_decomposition(aten.round)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def round(a: TensorLikeType, *, decimals: int = 0) -> TensorLikeType:
    if decimals == 0:
        return prims.round(a)
    else:
        ten_pow = 10**decimals
        ten_neg_pow = 10 ** (-decimals)
        return prims.mul(prims.round(prims.mul(a, ten_pow)), ten_neg_pow)
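
# Illustrative example of the decimals handling above: round(a, decimals=2)
# computes round(a * 100) * 0.01, so 3.14159 rounds to (approximately) 3.14.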


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rsqrt(a):
    return prims.rsqrt(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sigmoid(a: TensorLikeType) -> TensorLikeType:
    return true_divide(1, add(1, exp(neg(a))))


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def sgn(a):
    if utils.is_complex_dtype(a.dtype):
        a_abs = a.abs()
        return torch.where(a_abs == 0, 0, a / a_abs)
    else:
        return a.sign()


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def sign(a):
    return prims.sign(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    exact_dtype=True,
)
def signbit(a):
    return prims.signbit(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sin(a):
    return prims.sin(a)


# Autograd note: This will give the right first derivative at zero (by chance),
# but not the right second derivative
@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinc(a):
    a = math.pi * a
    return torch.where(a == 0, 1, torch.sin(a) / a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sinh(a):
    return prims.sinh(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def sqrt(a):
    return prims.sqrt(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
    aten_op=None,  # CompositeImplicitAutograd,
)
def square(a: TensorLikeType) -> TensorLikeType:
    return mul(a, a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tan(a):
    return prims.tan(a)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def tanh(a):
    return prims.tanh(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    exact_dtype=True,
)
def trunc(a):
    return prims.trunc(a)


# TODO: register this as a real ref/decomposition once TorchInductor supports complex!
def view_as_complex(self: TensorLikeType) -> TensorLikeType:
    input_dtype = self.dtype
    torch._check(
        utils.is_float_dtype(input_dtype),
        lambda: f"view_as_complex is only supported for floating point"
        f"tensors, but got a tensor of scalar type: {input_dtype}",
    )
    sizes = self.size()
    torch._check(
        len(sizes) != 0,
        lambda: "Input tensor must have one or more dimensions",
    )
    torch._check(
        sizes[-1] == 2,
        lambda: "Tensor must have a last dimension of size 2",
    )

    old_strides = self.stride()
    torch._check(
        old_strides[-1] == 1,
        lambda: "Tensor must have a last dimension with stride 1",
    )
    dims = old_strides[:-1]
    torch._check(
        builtins.all(stride % 2 == 0 for stride in dims),
        lambda: "Tensor must have a stride divisible by 2 for all but last dimension",
    )
    torch._check(
        self.storage_offset() % 2 == 0,
        lambda: "Tensor must have a storage_offset divisible by 2",
    )
    return prims.view_element_type(
        self, utils.corresponding_complex_dtype(input_dtype)
    ).squeeze(-1)
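
# Illustrative example (the last dimension of size 2 and stride 1 is reinterpreted
# as the real/imaginary parts):
# >>> torch.view_as_complex(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
# tensor([1.+2.j, 3.+4.j])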


def _make_elementwise_binary_reference(
    type_promotion_kind,
    aten_op=infer_aten_op,
    name=None,
    has_out=True,
    supports_lhs_python_scalar=True,
    supports_rhs_python_scalar=True,
    supports_two_python_scalars=False,
    should_register_decomposition=True,
) -> Callable:
    def inner(prim: Callable):
        nonlocal aten_op, name
        if name is None:
            name = prim.__name__

        @wraps(prim)
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=type_promotion_kind,
        )
        def _ref(
            a: Union[Tensor, NumberType],
            b: Union[Tensor, NumberType],
        ) -> Tensor:
            torch._check_value(
                supports_lhs_python_scalar or not isinstance(a, Number),
                lambda: f"{name}: Received a lhs Python scalar to an elementwise binary "
                "operation that does not accept lhs scalars!",
            )
            torch._check_value(
                supports_rhs_python_scalar or not isinstance(b, Number),
                lambda: f"{name}: Received a rhs Python scalar to an elementwise binary "
                "operation that does not accept rhs scalars!",
            )
            torch._check_value(
                supports_two_python_scalars
                or not (isinstance(a, Number) and isinstance(b, Number)),
                lambda: f"{name}: Receive two Number inputs to an elementwise binary operation!",
            )
            a, b = _maybe_broadcast(a, b)
            output = prim(a, b)
            return handle_noncontiguous_outputs([a, b], output)

        if has_out:
            _ref = out_wrapper()(_ref)  # type: ignore[assignment]

        _ref.__name__ = name
        if aten_op is infer_aten_op:
            aten_op = utils.get_aten_op(prim, name)
        if aten_op is not None and should_register_decomposition:
            register_decomposition(aten_op)(_ref)

        return _ref

    return inner


# Add has its own implementation because it has an alpha argument
@register_decomposition(aten.add)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def add(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: Optional[NumberType] = None,
):
    """
    Reference implementation of torch.add
    """

    a, b = _maybe_broadcast(a, b)

    if alpha is not None:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if python_type != bool and not utils.is_weakly_lesser_type(
            type(alpha), python_type
        ):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        if isinstance(b, TensorLike):
            b = prims.mul(b, alpha)
        else:
            b = b * alpha

    output = prims.add(a, b)
    return handle_noncontiguous_outputs([a, b], output)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def atan2(a, b):
    return prims.atan2(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_and(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_and(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_left_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.shift_left(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_or(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_or(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_right_shift(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.shift_right_arithmetic(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def bitwise_xor(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.bitwise_xor(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
)
def copysign(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    if isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
        raise RuntimeError(msg)
    return where(signbit(b), neg(abs(a)), abs(a))


# complex =  _make_elementwise_binary_reference(prims.complex, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)


@register_decomposition(aten.div)
@out_wrapper()
def div(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    rounding_mode: Optional[str] = None,
):
    """
    Reference implementation of torch.div
    """
    if rounding_mode is None:
        return true_divide(a, b)
    elif rounding_mode == "trunc":
        return trunc_divide(a, b)
    elif rounding_mode == "floor":
        return floor_divide(a, b)
    else:
        msg = f"div expected rounding_mode to be one of None, 'trunc', or 'floor' but found {rounding_mode}."
        raise ValueError(msg)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def eq(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.eq(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG,
)
def pow(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> TensorLikeType:
    assert isinstance(a, TensorLikeType) or isinstance(b, TensorLikeType)

    if isinstance(b, Number):
        if b == 1.0:
            return a.clone()  # type: ignore[return-value,union-attr]
        elif b == 2.0:
            return a * a  # type: ignore[return-value]
        elif b == 0.5:
            return torch.sqrt(a)  # type: ignore[arg-type]
    elif isinstance(a, Number):
        if a == 1.0:
            return torch.fill(b, True)
        if a == 2.0 and (
            utils.is_float_dtype(b.dtype) or utils.is_complex_dtype(b.dtype)
        ):
            return torch.exp2(b)

    return prims.pow(a, b)


# Float power has its own implementation because it has unique type promotion.
# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def float_power(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
) -> Tensor:
    if isinstance(a, Number) and isinstance(b, Number):
        raise ValueError(
            "Receive two Number inputs to an elementwise binary operation!"
        )

    # Handles type promotion
    dtype = utils.get_higher_dtype(a, b)
    assert dtype is not None
    if utils.is_complex_dtype(dtype):
        dtype = torch.complex128
    else:
        dtype = torch.float64

    # Float power has the following cast behavior to be
    # consistent with its C++ impl
    a = _maybe_convert_to_dtype(a, dtype)
    b = _maybe_convert_to_dtype(b, dtype)

    a, b = _maybe_broadcast(a, b)
    return pow(a, b)


# Note [Floor Division]
# Floor division is defined as a // b = (a - remainder(a, b)) / b, which can differ
# from floor(a / b). With
#
# >>> a = torch.tensor(-0.2500, dtype=torch.float64)
# tensor(-0.250000000000000, dtype=torch.float64)
#
# >>> b = torch.tensor(-0.0010, dtype=torch.float64)
# tensor(-0.001000000000000, dtype=torch.float64)
#
# (Note: casting a float to double expands the float mantissa with zeros, while
# creating a double directly generates a distinct mantissa:
# >>> torch.tensor(-0.001).to(dtype=torch.float64)
# tensor(-0.001000000047497, dtype=torch.float64)
# )
#
# the two computations disagree because torch.remainder(a, b) = -0.001:
#
# >>> torch.floor(torch.true_divide(a, b))
# tensor(250., dtype=torch.float64)
#
# >>> torch.div(a, b, rounding_mode='floor')
# tensor(249., dtype=torch.float64)
#
# Applying the definition a // b = (a - remainder(a, b)) / b directly:
# >>> torch.true_divide(torch.sub(a, torch.remainder(a, b)), b)
# tensor(249., dtype=torch.float64)
#
# For reference, see CPython's implementation:
# https://github.com/python/cpython/blob/ace008c531dd685a30c1dd68f9b5ba35f20171cf/Objects/floatobject.c#L636


@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
    should_register_decomposition=False,
)
def floor_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    # Wrap scalars because some references only accept tensor arguments.
    if isinstance(a, Number) and isinstance(b, Number):
        a = scalar_tensor(a)
        b = scalar_tensor(b)
    elif isinstance(b, Number) and isinstance(a, Tensor):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(a, Number) and isinstance(b, Tensor):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
        if a.device == torch.device("cpu"):
            msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
            raise RuntimeError(msg)
        else:
            b = prims.device_put(b, device=a.device)

    assert isinstance(a, Tensor) and isinstance(b, Tensor)
    dtype = a.dtype
    if utils.is_float_dtype(dtype):
        return _floor_divide_float(a, b)
    elif utils.is_integer_dtype(dtype):
        return _floor_divide_integer(a, b)
    else:
        torch._check(False, lambda: f"{dtype} not supported for floor_divide")


def _floor_divide_integer(a: Tensor, b: Tensor) -> Tensor:
    a, b = _maybe_broadcast(a, b)

    if not a.dtype.is_signed:
        return prims.div(a, b)

    # Convert truncation to flooring:
    offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
    return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
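
# Illustrative example of the truncation-to-flooring correction above:
# for a = -7, b = 2, prims.div truncates to -3; the signs differ and
# torch.fmod(-7, 2) == -1 != 0, so the offset is 1 and the result is
# -3 - 1 == -4, matching Python's floor division.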


def _floor_divide_float(a: Tensor, b: Tensor) -> Tensor:
    mod = fmod(a, b)
    div = true_divide(sub(a, mod), b)

    # Ensure that the remainder has the same sign as denominator
    different_signed_inputs = bitwise_xor(lt(a, 0), lt(b, 0))
    non_zero_remainder = ne(mod, 0)
    mask = bitwise_and(non_zero_remainder, different_signed_inputs)
    div = where(mask, sub(div, 1), div)

    # Map quotient to nearest integer value
    floor_div = floor(div)
    mask = gt(sub(div, floor_div), 0.5)
    floor_div = where(mask, add(floor_div, 1), floor_div)

    basic_div = true_divide(a, b)
    zero_tensor = scalar_tensor(0, dtype=basic_div.dtype, device=basic_div.device)

    # If quotient is zero, copy signbit from true_divide quotient
    floor_div = where(ne(div, 0), floor_div, copysign(zero_tensor, basic_div))

    # If denominator is zero, then follow true_divide behavior
    return where(ne(b, 0), floor_div, basic_div)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def fmax(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmax(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def fmin(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmin(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=True,
)
def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.fmod(a, b)


@register_decomposition(aten.frexp)
@out_wrapper("mantissa", "exponent")
def frexp(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]:
    return torch.return_types.frexp(prims.frexp(self))


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def gcd(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.gcd(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def ge(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.ge(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def gt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.gt(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def heaviside(input: TensorLikeType, values: TensorLikeType) -> TensorLikeType:
    input_eq_zero = torch.eq(input, 0)
    input_lt_zero = torch.logical_or(torch.lt(input, 0), torch.isnan(input))
    zeros_and_ones = torch.where(input_lt_zero, 0, 1)
    output = torch.where(input_eq_zero, values, zeros_and_ones)
    return output


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def hypot(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.hypot(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def igamma(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.igamma(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def igammac(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.igammac(a, b)


def _check_close_args(
    name: str,
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float,
    atol: float,
) -> None:
    torch._check_value(
        a.dtype == b.dtype,
        lambda: f"{name}: Attempting to compare tensors of different dtypes {a.dtype} and {b.dtype}!",
    )
    torch._check(
        rtol >= 0,
        lambda: f"{name}: rtol must be greater than or equal to zero, but got {rtol}!",
    )
    torch._check(
        atol >= 0,
        lambda: f"{name}: atol must be greater than or equal to zero, but got {atol}!",
    )


# CompositeImplicitAutograd - don't register decomp
def isclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> TensorLikeType:
    _check_close_args(name="torch.isclose", a=a, b=b, rtol=rtol, atol=atol)

    close = eq(a, b)
    if equal_nan and (utils.is_float_dtype(a.dtype) or utils.is_complex_dtype(a.dtype)):
        close = logical_or(close, logical_and(isnan(a), isnan(b)))

    # Note: In case of zero tolerances the closeness inequality degenerates to an equality check.
    # In this case, the short-circuit prevents false positives as detailed in the paragraph below.
    if atol == 0 and rtol == 0:
        return close

    # Note [closeness error computation]
    # atol and rtol are provided as doubles, so the computation
    # rtol * other will produce a float or complex tensor.
    # When the difference (self - other) is compared to it then the
    # tensor representing the difference will also be cast to float or complex.
    # However, since (self - other) in uint8 is very likely to produce a
    # negative value, this moves the cast forward so the difference is
    # always computed in a float or complex type.
    # If the values of the integer tensors cannot be exactly represented
    # by the default scalar type then this may cause an incorrect result.
    if not utils.is_float_dtype(a.dtype) and not utils.is_complex_dtype(a.dtype):
        a = prims.convert_element_type(a, torch.get_default_dtype())
        b = prims.convert_element_type(b, torch.get_default_dtype())

    allowed_error = add(atol, abs(mul(b, rtol)))
    actual_error = abs(sub(a, b))

    # Computes finite closeness
    result = logical_or(
        close, logical_and(isfinite(actual_error), le(actual_error, allowed_error))
    )

    return result
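
# The finite closeness test above implements the documented torch.isclose check
#   |a - b| <= atol + rtol * |b|
# with the extra requirement that |a - b| is finite.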


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def lcm(a: TensorLikeType, b: TensorLikeType):
    dtype = a.dtype
    # promoting to int32 to maintain 100% consistency with C++ and to
    # prevent overflow in case of int8 and int16
    promote_to_int = dtype in (torch.int8, torch.int16)
    if promote_to_int:
        a = prims.convert_element_type(a, torch.int32)
        b = prims.convert_element_type(b, torch.int32)

    g = torch.gcd(a, b)
    # Avoid division by zero in case gcd(0, 0) == 0
    g = torch.where(g == 0, 1, g)
    res = torch.abs(prims.div(a, g) * b)
    return res if not promote_to_int else prims.convert_element_type(res, dtype)
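
# The computation above relies on lcm(a, b) == |a / gcd(a, b) * b|; dividing by
# the gcd before multiplying reduces the risk of intermediate overflow.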


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def le(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.le(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    # Nb. this implementation does not distribute the gradients evenly when a == b
    mask = torch.real(a) >= torch.real(b)
    max_ = torch.where(mask, a, b)
    min_ = torch.where(mask, b, a)
    inf_mask = torch.logical_and(
        torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b)
    )
    if utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype):
        # Handle complex edge cases (infinities and nans)
        neg_min_mask = torch.real(min_) < 0
        inf_vals = torch.where(
            neg_min_mask, min_, torch.log(torch.exp(min_) + torch.exp(max_))
        )
        non_nan_vals = torch.where(
            inf_mask, inf_vals, max_ + torch.log1p(torch.exp(min_ - max_))
        )
        # the type for full_like does not include tensor yet
        nan_mask = torch.isnan(min_)
        return torch.where(nan_mask, complex(float("nan"), float("nan")), non_nan_vals)  # type: ignore[call-overload]
    else:
        return torch.where(inf_mask, a, max_ + torch.log1p(torch.exp(min_ - max_)))


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def logaddexp2(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    torch._check(
        not (utils.is_complex_dtype(a.dtype) or utils.is_complex_dtype(b.dtype)),
        lambda: "logaddexp2 doesn't support complex dtypes",
    )
    # Nb. this implementation does not distribute the gradients evenly when a == b
    mask = a >= b
    max_ = torch.where(mask, a, b)
    min_ = torch.where(mask, b, a)
    inf_mask = torch.logical_and(torch.isinf(a), a == b)
    inv_log_2 = 1.0 / math.log(2)
    result = max_ + torch.log1p(torch.exp2(min_ - max_)) * inv_log_2
    return torch.where(inf_mask, a, result)
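
# The result above uses log1p(exp2(d)) / ln(2) == log2(1 + 2**d), i.e. the shifted
# identity logaddexp2(a, b) = max(a, b) + log2(1 + 2**(min(a, b) - max(a, b))).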


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_and(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return a & b


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL)
def logical_not(a: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        return a == 0
    return ~a


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_or(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return bitwise_or(a, b)


# TODO: skip unnecessary conversion of long to float
@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
)
def logical_xor(a: TensorLikeType, b: TensorLikeType):
    if not utils.is_boolean_dtype(a.dtype):
        a = a != 0
    if not utils.is_boolean_dtype(b.dtype):
        b = b != 0
    return a ^ b


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def lt(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.lt(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def maximum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.maximum(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def minimum(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.minimum(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
)
def mul(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.mul(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
    supports_lhs_python_scalar=False,
)
def ne(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.ne(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
    supports_lhs_python_scalar=False,
    supports_rhs_python_scalar=False,
)
def nextafter(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.nextafter(a, b)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.remainder(a, b)


# reverse sub
@register_decomposition(aten.rsub)
@out_wrapper()
def rsub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    alpha: NumberType = 1,
):
    if isinstance(a, Number):
        msg = "Received a Number for the first argument, but expected a Tensor"
        raise ValueError(msg)

    return torch.sub(b, a, alpha=alpha)


# TODO: consider refactoring this with add impl
# sub has its own implementation because it has an alpha argument
@register_decomposition(aten.sub)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def sub(
    a: Union[TensorLikeType, NumberType],
    b: Union[TensorLikeType, NumberType],
    *,
    alpha: NumberType = 1,
):
    """
    Reference implementation of torch.sub
    """

    a, b = _maybe_broadcast(a, b)

    if isinstance(a, TensorLike) and isinstance(b, TensorLike):
        torch._check(
            not utils.is_boolean_dtype(a.dtype) and not utils.is_boolean_dtype(b.dtype),
            lambda: (
                "Subtraction, the `-` operator, with two bool tensors is not supported. "
                "Use the `^` or `logical_xor()` operator instead."
            ),
        )

    if alpha != 1:
        dtype = a.dtype if isinstance(a, TensorLike) else b.dtype  # type: ignore[union-attr]
        python_type = utils.dtype_to_type(dtype)
        if not utils.is_weakly_lesser_type(type(alpha), python_type):
            msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!"
            raise ValueError(msg)
        if isinstance(b, torch.Tensor):
            b = prims.mul(b, alpha)
        else:
            # Be careful not to use prims.mul if b is a scalar / symint:
            # prims.mul always returns a tensor,
            # which would mess with type promotion.
            b = b * alpha

    output = prims.sub(a, b)
    return handle_noncontiguous_outputs([a, b], output)


@_make_elementwise_binary_reference(
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    name="true_divide",
    aten_op=None,  # CompositeImplicitAutograd
    supports_two_python_scalars=True,
)
def true_divide(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.div(a, b)


@register_decomposition(aten.xlogy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    torch._check(
        isinstance(a, TensorLike) or isinstance(b, TensorLike),
        lambda: "Expected either argument a or b to be a Tensor",
    )

    # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
    if isinstance(b, TensorLike) and isinstance(a, Number):
        a = scalar_tensor(a, dtype=b.dtype, device=b.device)
    elif isinstance(a, TensorLike) and isinstance(b, Number):
        b = scalar_tensor(b, dtype=a.dtype, device=a.device)

    # mypy: expected "Tensor"
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)
    rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, torch.log(b)))
    return torch.where(torch.isnan(b), float("nan"), rhs)
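
# Illustrative corner cases implied by the two `where` calls above:
# xlogy(0, 0) returns 0 (the a == 0 branch), while xlogy(0, nan) returns nan
# (the isnan(b) branch takes precedence), matching torch.xlogy.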


@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    aten_op=None,  # CompositeImplicitAutograd
    supports_two_python_scalars=True,
)
def trunc_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
    dtype = utils.get_dtype(a)
    if utils.is_integer_dtype(dtype):
        return prims.div(a, b)

    return trunc(prims.div(a, b))


#
# Elementwise Ternary References
#


@register_decomposition(aten.addcdiv)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def addcdiv(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addcdiv
    """
    if value is not None:
        dtype = self.dtype  # no scalars allowed, see add
        python_type = utils.dtype_to_type(dtype)
        torch._check_value(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
        )

    return self + value * tensor1 / tensor2


@register_decomposition(aten.addcmul)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "tensor1", "tensor2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addcmul(
    self: TensorLikeType,
    tensor1: TensorLikeType,
    tensor2: TensorLikeType,
    *,
    value: NumberType = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.addcmul
    """
    if value is not None:
        dtype = self.dtype  # no scalars allowed, see add
        python_type = utils.dtype_to_type(dtype)
        torch._check_value(
            utils.is_weakly_lesser_type(type(value), python_type),
            lambda: f"value argument of type {type(value)} cannot be safely cast to type {python_type}!",
        )

    return self + value * tensor1 * tensor2


@register_decomposition(aten.clamp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "min", "max"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def clamp(
    a: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    # NOTE: the gradient behavior of this `where`-based implementation
    # is not consistent for `nan` inputs
    if min is None and max is None:
        msg = "clamp called but both min and max are none!"
        raise ValueError(msg)
    if min is not None:
        a_isnan = torch.isnan(a)
        condition = torch.bitwise_or(torch.ge(a, min), a_isnan)  # type: ignore[arg-type]
        # we should also propagate `nan` coming from the boundaries. However, that's
        # not necessary since `ge` is already `False` when either operand is `nan`,
        # so the line below would be redundant:
        #   `condition = bitwise_and(condition, bitwise_not(isnan(min)))`
        a = torch.where(condition, a, min)  # type: ignore[arg-type]
    if max is not None:
        a_isnan = torch.isnan(a)
        # same as above, no need to adjust `nan` from `max`
        condition = torch.bitwise_or(torch.le(a, max), a_isnan)  # type: ignore[arg-type]
        a = torch.where(condition, a, max)  # type: ignore[arg-type]

    return a


@register_decomposition(aten.clamp_min)
@out_wrapper()
def clamp_min(
    self: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    return torch.clamp(self, min=min)  # type: ignore[arg-type]


@register_decomposition(aten.clamp_max)
@out_wrapper()
def clamp_max(
    self: TensorLikeType,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    return torch.clamp(self, max=max)  # type: ignore[arg-type]


#
# Conditional references
#


# https://pytorch.org/docs/stable/generated/torch.where.html
# TODO: implement alternate where
@register_decomposition(aten.where)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def where(
    pred: Tensor,
    a: Optional[TensorOrNumberLikeType] = None,
    b: Optional[TensorOrNumberLikeType] = None,
):
    """ """

    if a is None or b is None:
        raise NotImplementedError

    utils.check_same_device(pred, a, b, allow_cpu_scalar_tensors=True)
    torch._check(
        pred.dtype is torch.bool,
        lambda: f"expected predicate to be bool, got {pred.dtype}",
    )

    pred, a, b = _maybe_broadcast(pred, a, b)
    return prims.where(pred, a, b)


#
# Data Movement References
#
@register_decomposition(aten.clone)
@out_wrapper()
def clone(
    a: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
) -> TensorLikeType:
    result = prims.clone(a, memory_format=memory_format)
    return result


def copy_to(a: Tensor, b: Tensor, *, allow_cross_device=True):
    if not allow_cross_device and a.device != b.device:
        msg = f"Attempting to copy from device {b.device} to device {a.device}, but cross-device copies are not allowed!"
        raise RuntimeError(msg)

    return prims.copy_to(a, b)


@register_decomposition(aten.item)
def item(a: TensorLikeType) -> NumberType:
    if a.numel() != 1:
        msg = f"Can't convert a tensor with {a.numel()} elements to a number!"
        raise ValueError(msg)

    # NOTE: explicit conversion is necessary for bool!
    # See https://github.com/pytorch/pytorch/issues/78071
    number_type = utils.dtype_to_type(a.dtype)
    return number_type(prims.item(a))


# Fast path for when `to` returns an alias of its input. This mimics the function of the same name in ATen.
def _to_will_alias(
    a: TensorLikeType,
    device: Optional[DeviceLikeType] = None,
    dtype: Optional[torch.dtype] = None,
    copy: Optional[bool] = None,
    layout: Optional[torch.layout] = None,
    memory_format: Optional[torch.memory_format] = None,
    pin_memory: Optional[bool] = False,
    non_blocking: bool = False,  # not using non_blocking
) -> bool:
    return (
        not copy
        and (device is None or a.device == device)
        and (dtype is None or a.dtype == dtype)
        and (layout is None or a.layout == layout)
        # is_pinned issue #84925
        # and (pin_memory is None or pin_memory == a.is_pinned())
        and (
            memory_format is None
            or memory_format == torch.preserve_format
            or utils.is_contiguous_for_memory_format(a, memory_format=memory_format)
        )
    )
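

# Illustrative sketch (hypothetical helper, for documentation only): `to` can be
# a no-op alias when none of the requested properties differ from the input, for
# example when only the current dtype is requested without copy.
def _to_will_alias_example(a: TensorLikeType) -> bool:
    # Expected: True -- nothing about `a` would need to change.
    return _to_will_alias(a, dtype=a.dtype, copy=False)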


@singledispatch
def _to_dispatch(*args, **kwargs):
    raise NotImplementedError


@_to_dispatch.register
def _to_device(
    device: torch.device,
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> dict[str, Any]:
    kwargs = {
        "device": device,
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
    return kwargs


@_to_dispatch.register
def _to_device_str(
    device: str,
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> dict[str, Any]:
    kwargs = {
        "device": torch.device(device),
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
    return kwargs


@_to_dispatch.register
def _to_dtype(
    dtype: torch.dtype,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> dict[str, Any]:
    kwargs = {
        "dtype": dtype,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
    return kwargs


@_to_dispatch.register
def _to_other(
    other: Tensor,
    non_blocking: bool = False,
    copy: bool = False,
    memory_format: Optional[torch.memory_format] = None,
) -> dict[str, Any]:
    device = other.device
    dtype = other.dtype
    layout = other.layout
    # is_pinned issue #84925
    # pin_memory = other.is_pinned()
    kwargs = {
        "device": device,
        "dtype": dtype,
        "layout": layout,
        "non_blocking": non_blocking,
        "copy": copy,
        "memory_format": memory_format,
    }
    return kwargs
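

# Illustrative sketch (hypothetical helper, for documentation only): the
# singledispatch table above normalizes the positional-argument overloads of `to`
# into a kwargs dict keyed on the type of the first positional argument.
def _to_dispatch_example(other: TensorLikeType) -> dict[str, Any]:
    # A tensor argument routes to `_to_other`, so the returned dict carries the
    # tensor's device, dtype and layout along with the flags passed here.
    return _to_dispatch(other, non_blocking=False, copy=False)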


# Removes entries of to_kwargs that are already satisfied by `a`
def _canonicalize_to_arguments(a: Tensor, to_kwargs: dict):
    options_to_check = ["dtype", "device", "layout", "memory_format"]
    # "device" option could be passed a str instead torch.device
    if "device" in to_kwargs and isinstance(to_kwargs["device"], str):
        to_kwargs["device"] = torch.device(to_kwargs["device"])

    for kw in options_to_check:
        if kw in to_kwargs:
            if (
                (kw == "memory_format" and to_kwargs[kw] is torch.preserve_format)
                or (
                    kw == "device"
                    and to_kwargs[kw].type == a.device.type
                    and (
                        not to_kwargs[kw].index or to_kwargs[kw].index == a.device.index
                    )
                )
                or (
                    getattr(a, kw, None) == to_kwargs[kw]
                )  # this also handles {"memory_format": None}
            ):
                to_kwargs.pop(kw)


def to(a: TensorLikeType, *args, **kwargs) -> TensorLikeType:
    # handle dispatch via positional arguments
    if len(args) != 0:
        kwargs = _to_dispatch(*args, **kwargs)

    # TODO: is_pinned is not currently supported in refs or fake_tensor
    # https://github.com/pytorch/pytorch/issues/84925
    assert "pin_memory" not in kwargs
    _canonicalize_to_arguments(a, kwargs)

    if _to_will_alias(a, **kwargs):
        return a

    copy = kwargs.pop("copy") if "copy" in kwargs else False
    non_blocking = kwargs.pop("non_blocking") if "non_blocking" in kwargs else False

    # short-circuit to `prims.convert_element_type` when `to` is just a dtype change
    if (
        (copy or (kwargs.get("dtype", a.dtype) != a.dtype))
        and (not non_blocking)
        and ("memory_format" not in kwargs)
        and ("device" not in kwargs)
        and ("layout" not in kwargs)
        # is_pinned issue #84925
        # and ("pin_memory" not in kwargs)
    ):
        return prims.convert_element_type(a, kwargs.get("dtype", a.dtype))

    result = torch.empty_like(a, **kwargs)
    # TODO: non_blocking should be handled by `copy_to`
    copy_to(result, a)
    return result


#
# Reduction references
#


def _reduction(
    a: TensorLikeType,
    prim: Callable,
    *,
    has_identity: bool = True,
    accepts_dim_tuple: bool = True,  # to handle min/argmin that accept single dim only
    dims: Optional[DimsType] = None,
    keepdims: bool = False,
    dtype: Optional[torch.dtype] = None,  # should be specified for ops that support it
    out: Optional[Tensor] = None,
    # output_dtype_kind is usually SAME, but it has no default so that ref
    # writers actually think about what to put here
    output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    if a.ndim > 64:
        raise RuntimeError(
            f"Received a tensor with {a.ndim} dimensions, but only tensors with up to 64 dims are supported!"
        )

    if out is not None:
        assert isinstance(out, TensorLike)
        if dtype is not None:
            # TODO - this is true for eager mode currently, but it's wrong behavior for complex norms
            if dtype != out.dtype:
                raise RuntimeError(
                    "dtype argument and out dtype must match in reduction"
                )
    if not accepts_dim_tuple:
        assert dims is None or isinstance(dims, Dim)
    if isinstance(dims, Dim):
        dims = (dims,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dims)
    if not has_identity:
        valid_shape = a.ndim == 0 or builtins.all(a.shape[i] for i in dims)
        if not valid_shape:
            raise RuntimeError(
                "reducing over zero-size dimension for reduction operation without identity"
            )
    computation_dtype, result_dtype = utils.reduction_dtypes(
        a, output_dtype_kind, dtype
    )
    a = _maybe_convert_to_dtype(a, computation_dtype)  # type: ignore[method-assign]
    result = prim(a, dims)
    if keepdims:
        output_shape = [a.shape[i] if i not in dims else 1 for i in range(a.ndim)]
        broadcast_dims = [i for i in range(a.ndim) if i not in dims]
        result = prims.broadcast_in_dim(result, output_shape, broadcast_dims)

    if out is not None:
        assert result_dtype is not None
        if dtype is not None and result_dtype != out.dtype:
            raise RuntimeError(
                "Expected the dtype of reduction result and out to match"
            )
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]

    if result.dtype != result_dtype and result_dtype is not None:
        result = prims.convert_element_type(result, result_dtype)

    return result


def _make_copy_from_view(fn):
    """
    Given a view function (e.g. torch.diagonal), generates its copy variant (e.g. torch.diagonal_copy).
    """
    aten_fn = getattr(aten, fn.__name__)
    annotations = getattr(fn, "__annotations__", {})
    fn = out_wrapper()(aten_fn)

    @wraps(fn)
    def _fn(*args, out=None, **kwargs):
        result = fn(*args, out=out, **kwargs)
        if out is not None:
            return result

        return pytree.tree_map(
            lambda x: x.clone(memory_format=torch.contiguous_format),
            result,
        )

    copy_name = f"{fn.__name__}_copy"
    _fn.__name__ = copy_name
    _fn.__annotations__.update(annotations)
    register_decomposition(getattr(aten, copy_name))(_fn)
    return _fn
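

# Illustrative sketch (hypothetical helper, for documentation only): a copy
# variant generated by `_make_copy_from_view` behaves like the view op followed by
# a contiguous clone, e.g. torch.diagonal_copy vs. torch.diagonal.
def _diagonal_copy_equivalent_example(a: TensorLikeType) -> TensorLikeType:
    # Expected to match torch.diagonal_copy(a) for an ordinary tensor input.
    return torch.diagonal(a).clone(memory_format=torch.contiguous_format)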


@register_decomposition(aten.all)
@out_wrapper()
def all(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))

    if a.dtype == torch.uint8:
        result = result.to(dtype=torch.uint8)

    return result


@register_decomposition(aten.any)
@out_wrapper()
def any(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
) -> TensorLikeType:
    a_ = _maybe_convert_to_dtype(a, torch.bool)
    if isinstance(dim, (list, tuple)) and len(dim) == 0:
        result = a_.clone()
    else:
        result = a_.sum(dim=dim, keepdim=keepdim).ne(False)

    # Preserves uint8 -- probably a legacy mask thing
    if a.dtype is torch.uint8:
        return prims.convert_element_type(result, torch.uint8)

    return result
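

# Illustrative sketch (hypothetical helper, for documentation only): `any` is
# decomposed as a boolean sum followed by `ne(False)`, with the result cast back
# to uint8 when the input is uint8 (a legacy masking convention).
def _any_decomposition_example() -> TensorLikeType:
    mask = torch.tensor([[0, 1], [0, 0]], dtype=torch.uint8)
    # Expected: tensor([1, 0], dtype=torch.uint8) -- row 0 has a nonzero entry,
    # row 1 does not.
    return any(mask, dim=1)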


@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out])
def sum(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[list[int]]] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    if dtype is None:
        if out is not None:
            dtype = out.dtype
        elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
            dtype = torch.int64
        else:
            dtype = a.dtype
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
    return _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )


def sum_to_size(
    a: Tensor,
    *shape,
) -> Tensor:
    shape = utils.extract_shape_from_varargs(shape, validate=False)
    torch._check(
        utils.is_expandable_to(shape, a.shape),
        lambda: f'sum_to_size: size "{shape}" is not expandable to size "{a.shape}"',
    )
    # In ATen, scalar tensors are sent through sum and the result is returned
    # type promoted
    if utils.is_same_shape(shape, a.shape) and len(shape) > 0:
        return prims.view_of(a)
    leading_dims = a.ndim - len(shape)
    reduce_dims = tuple(range(leading_dims)) + tuple(
        i
        for i in range(leading_dims, len(shape))
        if shape[i - leading_dims] == 1 and a.shape[i] != 1
    )
    return torch.sum(a, dim=reduce_dims, keepdim=True, dtype=None)
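

# Illustrative sketch (hypothetical helper, for documentation only): sum_to_size
# sums exactly the dimensions that would have been broadcast, keeping them as
# size-1 dims so the result has the requested shape.
def _sum_to_size_example() -> Tensor:
    a = torch.ones(2, 3)
    # Expected shape: (1, 3) -- dimension 0 is summed and kept with size 1.
    return sum_to_size(a, 1, 3)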


@register_decomposition(aten.prod)
def prod(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[list[int]]] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    if dtype is None:
        if out is not None:
            dtype = out.dtype
        elif utils.is_boolean_dtype(a.dtype) or utils.is_integer_dtype(a.dtype):
            dtype = torch.int64
        else:
            dtype = a.dtype
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
    return _reduction(
        a,
        prims.prod,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=out,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )


@register_decomposition(aten.amin)
def amin(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    return _reduction(
        a,
        prims.amin,
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )


@register_decomposition(aten.amax)
def amax(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    return _reduction(
        a,
        prims.amax,
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=out,
        has_identity=False,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.SAME,
    )


def _dim_var_dispatch(dim=None, unbiased=None):
    # torch.var (and the related var_mean/std/std_mean) have overloads like
    # var(Tensor self, bool unbiased=True), so a bool passed where `dim` would
    # normally go must be explicitly converted into the `unbiased` arg
    if unbiased is None and isinstance(dim, bool):
        unbiased = dim
        dim = None
    return dim, unbiased
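

# Illustrative sketch (hypothetical helper, for documentation only): a bool passed
# positionally where `dim` would normally go is reinterpreted as the `unbiased`
# flag, matching the legacy var/std overloads.
def _dim_var_dispatch_example():
    # Expected: (None, False) -- the bool "dim" becomes unbiased=False.
    return _dim_var_dispatch(dim=False, unbiased=None)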


@register_decomposition(aten.var)
@out_wrapper()
def var(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[NumberType] = None,
) -> TensorLikeType:
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None

    result = _reduction(
        a,
        partial(prims.var, correction=correction),
        dims=dim,
        keepdims=keepdim,
        dtype=None,
        out=None,
        has_identity=True,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    )
    return result


@register_decomposition(aten.std)
@out_wrapper()
def std(
    a: TensorLikeType,
    dim: Union[Optional[int], Optional[list[int]]] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[NumberType] = None,
) -> TensorLikeType:
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)

    opmath_dtype, dtype = utils.reduction_dtypes(
        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    )
    a = _maybe_convert_to_dtype(a, opmath_dtype)
    a_var = torch.var(a, dim, correction=correction, keepdim=keepdim)
    a_std = torch.sqrt(a_var)
    assert dtype is not None
    return _maybe_convert_to_dtype(a_std, dtype)


@register_decomposition(aten.mean)
def mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype=None,
    out=None,
) -> TensorLikeType:
    # reduces over all dimensions if dim=() is passed
    if dim == () or dim == []:
        dim = None
    orig_dtype = dtype
    if dtype is None:
        dtype = a.dtype
    result = _reduction(
        a,
        prims.sum,
        dims=dim,
        keepdims=keepdim,
        dtype=dtype,
        out=None,
        output_dtype_kind=REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE,
    )
    torch._check(
        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
        lambda: (
            f"mean(): could not infer output dtype. "
            f"{'Input' if orig_dtype is None else 'Optional'} dtype must be either "
            f"a floating point or complex dtype. Got: {dtype}"
        ),
    )
    if isinstance(dim, Dim):
        dim = (dim,)  # type: ignore[assignment]
    dims = utils.reduction_dims(a.shape, dim)  # type: ignore[arg-type]
    nelem = 1 if a.ndim == 0 else reduce(operator.mul, (a.shape[i] for i in dims), 1)
    result = true_divide(result, nelem)
    result_dtype = a.dtype if dtype is None else dtype
    result = _maybe_convert_to_dtype(result, result_dtype)  # type: ignore[method-assign]
    if out is not None:
        assert isinstance(out, TensorLike)
        out = _maybe_resize_out(out, result.shape)
        return _safe_copy_out(copy_from=result, copy_to=out)  # type: ignore[arg-type]
    return result
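

# Illustrative sketch (hypothetical helper, for documentation only): mean is a sum
# in the promoted dtype divided by the number of reduced elements; integral inputs
# need an explicit floating point (or complex) `dtype`, otherwise the check above
# raises.
def _mean_example() -> TensorLikeType:
    a = torch.arange(6).reshape(2, 3)
    # Expected: tensor([1., 4.]) -- the per-row means, computed in float32.
    return mean(a, dim=1, dtype=torch.float32)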


@register_decomposition(aten.std_mean)
@out_wrapper("out0", "out1")
def std_mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    *,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    correction: Optional[NumberType] = None,
):
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    correction = utils.set_correction(unbiased, correction)
    opmath_dtype, dtype = utils.reduction_dtypes(
        a, REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
    )
    original_dtype = a.dtype
    a = _maybe_convert_to_dtype(a, opmath_dtype)
    a_var, a_mean = torch.var_mean(a, dim, correction=correction, keepdim=keepdim)
    a_std = torch.sqrt(a_var)
    assert dtype is not None
    return (
        _maybe_convert_to_dtype(a_std, dtype),
        _maybe_convert_to_dtype(a_mean, original_dtype),
    )


@register_decomposition(aten.var_mean)
@out_wrapper("out0", "out1")
def var_mean(
    a: TensorLikeType,
    dim: Optional[DimsType] = None,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: Optional[NumberType] = None,
):
    dim, unbiased = _dim_var_dispatch(dim, unbiased)
    v = var(a, dim, unbiased, keepdim, correction=correction)
    m = mean(a, dim, keepdim)
    return v, m


@register_decomposition(aten.addr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "vec1", "vec2"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def addr(
    self: TensorLikeType,
    vec1: TensorLikeType,
    vec2: TensorLikeType,
    *,
    beta: NumberType = 1,
    alpha: NumberType = 1,
) -> TensorLikeType:
    torch._check(
        vec1.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec1, but got {vec1.ndim}-D",
    )
    torch._check(
        vec2.ndim == 1,
        lambda: f"addr: Expected 1-D argument vec2, but got {vec2.ndim}-D",
    )
    for arg, arg_name in ((alpha, "alpha"), (beta, "beta")):
        if isinstance(arg, bool):
            torch._check(
                utils.is_boolean_dtype(self.dtype)
                and utils.is_boolean_dtype(vec1.dtype)
                and utils.is_boolean_dtype(vec2.dtype),
                lambda: f"Boolean {arg_name} only supported for Boolean results.",
            )
    self = self.expand(vec1.shape[0], vec2.shape[0])
    if utils.is_boolean_dtype(self.dtype):
        # Integers are accepted for booleans
        torch._check(
            is_weakly_lesser_type(type(beta), int),
            lambda: f"expected bool/int beta but got {type(beta)}",
        )
        torch._check(
            is_weakly_lesser_type(type(alpha), int),
            lambda: f"expected bool/int alpha but got {type(beta)}",
        )
        if not beta:
            return torch.outer(vec1, vec2) if alpha else torch.full_like(self, False)
        else:
            return torch.logical_or(
                self,
                torch.outer(vec1, vec2) if alpha else torch.full_like(self, False),
            )
    else:
        torch._check(
            is_weakly_lesser_type(type(beta), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(beta)} to {self.dtype}",
        )
        torch._check(
            is_weakly_lesser_type(type(alpha), dtype_to_type(self.dtype)),
            lambda: f"cannot safely convert {type(alpha)} to {self.dtype}",
        )
        if beta == 0:
            # This means NaNs from self are dropped if beta is zero
            return alpha * torch.outer(vec1, vec2)
        else:
            return beta * self + alpha * torch.outer(vec1, vec2)
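

# Illustrative sketch (hypothetical helper, for documentation only): for
# non-boolean dtypes, addr computes beta * self + alpha * outer(vec1, vec2).
def _addr_example() -> TensorLikeType:
    m = torch.zeros(2, 3)
    vec1 = torch.tensor([1.0, 2.0])
    vec2 = torch.tensor([1.0, 10.0, 100.0])
    # Expected:
    # tensor([[  1.,  10., 100.],
    #         [  2.,  20., 200.]])
    return addr(m, vec1, vec2, beta=1, alpha=1)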


# CompositeImplicitAutograd - don't register decomp
def atleast_1d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_1d`."""
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    res = tuple(a if a.ndim >= 1 else unsqueeze(a, 0) for a in args_)
    return res if len(res) > 1 else res[0]


# Helper function with assert to avoid MyPy error
# of incompatible type passed to unsqueeze
def _unsqueeze_atleast(
    at_least_fn: Callable, dim: int, arg: TensorLikeType
) -> TensorLikeType:
    arg_ = at_least_fn(arg)
    assert isinstance(arg_, TensorLike)
    return unsqueeze(arg_, dim)


# CompositeImplicitAutograd - don't register decomp
def atleast_2d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_2d`."""
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    unsqueeze_atleast_1d = partial(_unsqueeze_atleast, atleast_1d, 0)
    res = tuple(a if a.ndim >= 2 else unsqueeze_atleast_1d(a) for a in args_)
    return res if len(res) > 1 else res[0]


# CompositeImplicitAutograd - don't register decomp
def atleast_3d(
    arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
    """Reference implementation of :func:`torch.atleast_3d`."""
    if not args and isinstance(arg, collections.abc.Sequence):
        args_ = arg
    else:
        assert not isinstance(arg, collections.abc.Sequence)
        args_ = (arg,) + args
    unsqueeze_atleast_2d = partial(_unsqueeze_atleast, atleast_2d, -1)
    res = tuple(a if a.ndim >= 3 else unsqueeze_atleast_2d(a) for a in args_)
    return res if len(res) > 1 else res[0]
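

# Illustrative sketch (hypothetical helper, for documentation only): the atleast_*
# helpers promote lower-dimensional inputs by unsqueezing. A 0-D tensor becomes
# shapes (1,), (1, 1) and (1, 1, 1) respectively, while atleast_3d appends a
# trailing dimension to a 2-D (H, W) tensor, giving (H, W, 1).
def _atleast_example():
    scalar = torch.tensor(3.14)
    # Expected shapes: (1,), (1, 1), (1, 1, 1).
    return atleast_1d(scalar), atleast_2d(scalar), atleast_3d(scalar)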


def as_strided(
    a: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    storage_offset: Optional[int] = None,
) -> TensorLikeType:
    storage_offset_int = (
        storage_offset if storage_offset is not None else a.storage_offset()
    )
    return prims.as_strided(a, size, stride, storage_offset_int)


@register_decomposition(aten.as_strided_scatter)
@out_wrapper()
def as_strided_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    storage_offset: Optional[int] = None,
) -> TensorLikeType:
    storage_offset_int = 0 if storage_offset is None else storage_offset
    return prims.as_strided_scatter(input, src, size, stride, storage_offset_int)


def broadcast_shapes(*shapes) -> ShapeType:
    return torch.Size(_broadcast_shapes(*shapes))


@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
def broadcast_tensors(*tensors) -> list[TensorLikeType]:
    if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
        tensors = tensors[0]
    return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))


# CompositeImplicitAutograd - don't register decomp
def broadcast_to(a: TensorLikeType, size: ShapeType) -> TensorLikeType:
    start = len(size) - len(a.shape)
    dims = tuple(range(start, len(a.shape) + start))
    return prims.broadcast_in_dim(a, size, dims)


@register_decomposition(aten.cat)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("tensors",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH,
)
def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    def cat_compute_output_memory_format(inputs):
        format = None
        for t in inputs:
            f = utils.suggest_memory_format(t)
            if f == torch.contiguous_format:
                return f
            if format is not None and format != f:
                return torch.contiguous_format
            format = f
        assert format is not None
        return format

    if len(tensors) == 0:
        msg = "cat expects at least one tensor, but received zero!"
        raise ValueError(msg)

    for tensor in tensors:
        assert isinstance(tensor, TensorLike)

    utils.check_same_device(*tensors, allow_cpu_scalar_tensors=False)

    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # This is a bit tricky.  Naively, you would expect to just pick one
    # arbitrary tensor and check that all tensors match this tensor.  However,
    # there is legacy behavior which says that if you have a 1-D empty tensor
    # (0,), this is permissible.  So you can't assume that all the tensors
    # have the same dimensionality, and you can't assume that the first tensor is
    # the correct stencil.
    #
    # We'll implement this in a few passes.  First, we will try to infer the
    # ndim of the cat output.  If this ndim != 1, then we know that all ndim =
    # 1 inputs must be empty, or are errors.  If this ndim == 1, then life
    # is easy (the legacy special case coincides with regular handling).
    #
    # NB: The regular implementation of cat just filters out empty inputs,
    # but we do it slightly differently here for better handling of unbacked
    # SymInts

    example = None
    for i, t in enumerate(tensors):
        if example is None:
            if t.ndim != 1:
                example = t
        else:
            if t.ndim != 1:
                torch._check(
                    t.ndim == example.ndim,
                    lambda: "Number of dimensions of tensors must match.  "
                    f"Expected {example.ndim}-D tensors, but got {t.ndim}-D for "
                    f"tensor number {i} in the list",
                )

    if example is None:
        # example is None if everything is 1-D.  If so, just arbitrarily pick
        # the first one
        example = tensors[0]

    shape = example.shape
    filtered = []
    for tensor_idx, tensor in enumerate(tensors):
        if len(shape) != len(tensor.shape):
            assert tensor.ndim == 1  # we've already checked this above
            # Don't suggest the legacy behavior in the error message
            torch._check(
                # NB: it is not enough to simply assert that tensor.shape[0] == 0;
                # this MUST be true even under guard size oblivious.
                # Effectively, we must actually know that the shape is zero,
                # passing an unbacked SymInt which we will defer a runtime
                # assert on won't cut it.  This is a policy decision (size
                # oblivious semantics say that u0 tensors never are inferred
                # to be zero size, even if they must be that for the cat to go
                # through), and is load bearing for our Inductor lowerings
                # (which assume that size oblivious tests are OK to determine
                # if a shape is permissibly zero.)
                guard_size_oblivious(tensor.shape[0] == 0),
                lambda: f"Number of dimensions of tensors must match.  "
                f"Expected {example.ndim}-D tensors, but got 1-D for "
                f"tensor number {tensor_idx} in the list",
            )
        else:
            # Remove inputs that are 1-D, zero size
            if tensor.ndim == 1 and guard_size_oblivious(tensor.shape[0] == 0):
                continue
            # Don't bother checking size match, prims.cat will handle it
            filtered.append(tensor)

    memory_format = cat_compute_output_memory_format(tensors)

    if len(filtered) == 0:
        t = tensors[0]

        # TODO: fix this to work with meta tensors
        try:
            # BUG? This looks like it wants to call builtins.any() but is
            # actually calling .any() (in this file). Changing to builtins.any()
            # causes tests to fail:
            # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/test_ops.py -k \
            #   TestFakeTensorCUDA.test_fake_crossref_backward_amp_cat_cuda_float32
            requires_grad = bool(any(x.requires_grad for x in tensors))  # type: ignore[arg-type]
        except Exception:
            requires_grad = False  # type: ignore[assignment]

        return empty(
            (0,),
            dtype=t.dtype,
            device=t.device,
            requires_grad=requires_grad,
            memory_format=memory_format,
        )

    dim = utils.canonicalize_dim(filtered[0].ndim, dim)
    utils.validate_idx(filtered[0].ndim, dim)

    return prims.cat(filtered, dim).clone(memory_format=memory_format)
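

# Illustrative sketch (hypothetical helper, for documentation only): the legacy
# special case above lets 1-D, zero-size tensors participate in cat regardless of
# the other inputs' dimensionality; they are simply filtered out.
def _cat_legacy_empty_example() -> TensorLikeType:
    a = torch.ones(2, 3)
    b = torch.empty(0)  # 1-D with zero elements
    # Expected shape: (2, 3) -- the empty 1-D tensor contributes nothing.
    return cat((a, b), dim=0)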


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def column_stack(tensors: TensorSequenceType) -> TensorLikeType:
    aligned_tensors = tuple(
        x if x.ndim > 1 else x.reshape((x.numel(), 1)) for x in tensors
    )
    return cat(aligned_tensors, 1)


def conj(input: TensorLikeType) -> TensorLikeType:
    if not utils.is_complex_dtype(input.dtype):
        return input
    if input.is_sparse:
        return torch.conj_physical(input)
    return prims.conj(input)


# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp
@register_decomposition(aten.constant_pad_nd)
@out_wrapper()
def constant_pad_nd(
    input: TensorLikeType, pad: list[int], value: NumberType = 0
) -> TensorLikeType:
    torch._check(
        len(pad) % 2 == 0,
        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
    )

    input_sizes = input.shape
    l_inp = len(input_sizes)

    l_pad = len(pad) // 2
    l_diff = l_inp - l_pad

    torch._check(
        l_inp >= l_pad,
        lambda: "Length of pad should be no more than twice the number of "
        f"dimensions of the input. Pad length is {len(pad)} while the input has "
        f"{l_inp} dimensions.",
    )

    c_input = input
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] < 0:
            c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])

        if pad[pad_idx + 1] < 0:
            c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])

    # If all the pads are negative we can return the result.
    # Avoid early exiting if all pads = 0 to prevent specialization on export.
    # During export, raw if statements are specialized on the input, meaning
    # that we lose a branch depending on the example input used to export.
    # Here, this is either the case where all pads = 0, or the case where at
    # least one pad > 0 and the rest are >= 0.
    # Avoiding the early exit when all pads = 0 ensures we can export
    # constant_pad_nd for cases when all pads >= 0.
    # Note: if any pads are negative, this code specializes due to the if statements above.
    if builtins.all(p < 0 for p in pad):
        return c_input.clone()

    new_shape = list(input_sizes[:l_diff])

    for i in range(l_pad):
        pad_idx = len(pad) - ((i + 1) * 2)
        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
        torch._check(
            new_dim > 0,
            lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
            f"which is invalid. Check dimension {l_diff + i} of your input.",
        )
        new_shape.append(new_dim)

    memory_format = utils.suggest_memory_format(input)
    output = torch.empty(
        new_shape,
        dtype=input.dtype,
        device=input.device,
        requires_grad=input.requires_grad,
        memory_format=memory_format,
    )

    if value == 0 and input.dtype == torch.bool:
        value = False
    # torch.fill isn't typed to allow complex values
    output = torch.fill(output, value)  # type: ignore[arg-type]

    c_output = output
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] >= 0:
            c_output = c_output.narrow(
                i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
            )
        if pad[pad_idx + 1] >= 0:
            c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])

    prims.copy_to(c_output, c_input)
    return output
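

# Illustrative sketch (hypothetical helper, for documentation only): negative pad
# values narrow the corresponding side instead of padding it, mirroring
# at::constant_pad_nd.
def _constant_pad_nd_example() -> TensorLikeType:
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    # pad = [-1, 1]: drop one element at the front, append one `value` at the end.
    # Expected: tensor([2., 3., 4., 0.])
    return constant_pad_nd(x, [-1, 1], 0)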


def contiguous(
    a: Tensor, *, memory_format: torch.memory_format = torch.contiguous_format
) -> Tensor:
    torch._check(
        memory_format != torch.preserve_format,
        lambda: "preserve memory format is unsupported by the contiguous operator",
    )

    if utils.is_contiguous_for_memory_format(a, memory_format=memory_format):
        return a

    return torch.clone(a, memory_format=memory_format)


@out_wrapper()
def dstack(tensors: TensorSequenceType) -> TensorLikeType:
    torch._check(len(tensors) > 0, lambda: "dstack expects a non-empty TensorList")
    aligned_tensors = atleast_3d(*tensors)
    return cat(aligned_tensors, 2)


@register_decomposition(aten.expand)
def expand(a: Tensor, *shape) -> Tensor:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    # NOTE: cannot use utils.extract_shape_from_varargs here
    # because that also validates the shape, but the shape
    # given to expand may be "invalid"
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        shape = tuple(shape[0])

    torch._check(
        len(shape) >= len(a.shape),
        lambda: "expand: the requested shape has too few dimensions!",
    )

    offset = len(shape) - len(a.shape)
    shape_ = list(shape)
    for idx, x in enumerate(a.shape):
        offset_idx = idx + offset
        requested_length = shape[offset_idx]
        torch._check(
            guard_size_oblivious(requested_length == x)
            or guard_size_oblivious(x == 1)
            or requested_length == -1,
            lambda: f"expand: attempting to expand a dimension of length {x}!",
        )

        shape_[offset_idx] = requested_length if requested_length != -1 else x

    # At this point shape must be valid
    utils.validate_shape(shape_)

    return prims.broadcast_in_dim(
        a, shape_, tuple(range(offset, len(a.shape) + offset))
    )


# CompositeImplicitAutograd - don't register decomp
def expand_as(a: Tensor, b: Tensor) -> Tensor:
    return a.expand(b.shape)


def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> tuple[TensorLikeType, ...]:
    if chunks <= 0:
        msg = f"Expected at least one chunk, but got {chunks}!"
        raise ValueError(msg)

    dim = utils.canonicalize_dim(a.ndim, dim)
    length = a.shape[dim]
    chunk_size = math.ceil(length / chunks)
    full_chunks = math.floor(length / chunk_size)
    tail_chunk_size = length % chunk_size

    result = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)]

    if tail_chunk_size != 0:
        result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))

    return tuple(result)
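

# Illustrative sketch (hypothetical helper, for documentation only): chunk uses a
# chunk size of ceil(length / chunks), so the last chunk may be smaller and fewer
# than `chunks` pieces may be returned.
def _chunk_example() -> tuple[TensorLikeType, ...]:
    a = torch.arange(10)
    # chunk_size = ceil(10 / 3) = 4, so the expected chunk sizes are (4, 4, 2).
    return chunk(a, chunks=3, dim=0)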


# Note: flatten, unlike other shape operators, returns the input tensor for a no-op (unless
# a 0D tensor is flattened, in which case it's returned as a 1D tensor)
# CompositeImplicitAutograd - don't register decomp
def flatten(a: TensorLikeType, start_dim: int = 0, end_dim: int = -1) -> TensorLikeType:
    start_dim = utils.canonicalize_dim(a.ndim, start_dim)
    end_dim = utils.canonicalize_dim(a.ndim, end_dim)

    # Short-circuits on no-op
    if start_dim == end_dim and a.ndim != 0:
        return a

    # Tries to take a view
    # TODO: we could look at directing collapse_view to skip its meta function here (unsafe_collapse_view)
    new_shape, _new_strides = prims._collapse_view_helper(a, start_dim, end_dim)
    if new_shape is not None:
        return prims.collapse_view(a, start_dim, end_dim)

    # Makes a copy if it can't make a view
    return prims.collapse(a, start_dim, end_dim)


@register_decomposition(aten.flip)
@out_wrapper()
def flip(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    if not isinstance(dims, tuple) and not isinstance(dims, list):
        raise ValueError("dims has to be a sequence of ints")
    dims = utils.canonicalize_dims(a.ndim, dims)  # type: ignore[assignment]
    utils.validate_no_repeating_dims(dims)
    return prims.rev(a, dims)


# CompositeImplicitAutograd - don't register decomp
def fliplr(a: TensorLikeType) -> TensorLikeType:
    if a.ndim < 2:
        raise RuntimeError("Input must be >= 2-d.")

    return flip(a, (1,))


# CompositeImplicitAutograd - don't register decomp
def flipud(a: TensorLikeType) -> TensorLikeType:
    if a.ndim < 1:
        raise RuntimeError("Input must be >= 1-d.")

    return flip(a, (0,))


# CompositeImplicitAutograd - don't register decomp
def narrow(
    a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int
) -> TensorLikeType:
    # Supports Tensor overload that was added for XLA:
    # https://github.com/pytorch/pytorch/issues/31558
    if isinstance(start, TensorLike):
        torch._check(
            start.dim() == 0 and utils.is_integer_dtype(start.dtype),
            lambda: "start must be an 0-dim integral Tensor.",
        )
        start = start.item()  # type: ignore[assignment]
    start = cast(int, start)
    torch._check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.")
    torch._check(length >= 0, lambda: "narrow(): length must be non-negative.")
    dim = utils.canonicalize_dim(a.ndim, dim)
    dim_length = a.size(dim)
    torch._check_with(
        IndexError,
        -dim_length <= start and start <= dim_length,
        lambda: f"start out of range (expected to be in range of [{-dim_length}, {dim_length}], but got {start})",
    )
    if start < 0:
        start = start + dim_length
    torch._check(
        start <= dim_length - length,
        lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).",
    )
    new_shape = list(a.shape)
    new_shape[dim] = length
    return a.as_strided(
        new_shape, a.stride(), a.storage_offset() + a.stride(dim) * start
    )


def _normalize(
    a: Tensor, norm_dims: DimsType, eps: float
) -> tuple[Tensor, Tensor, Tensor]:
    """Computes mean and 1/std of a tensor along norm_dims.

    Used as a helper function for normalization layers.

    Args:
        a (Tensor): input tensor
        norm_dims (DimsType): dimensions to normalize over
        eps (float): epsilon for numerical stability

    Returns:
        out (Tensor): normalized tensor.
        mean (Tensor): mean of the tensor along norm_dims.
        rstd (Tensor): 1/std of the tensor along norm_dims.
    """
    norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
    computation_dtype = utils.get_computation_dtype(a.dtype)
    a_acc = _maybe_convert_to_dtype(a, computation_dtype)
    assert isinstance(a_acc, TensorLike)  # to avoid mypy error for var_mean
    biased_var, mean = torch.var_mean(
        a_acc, dim=norm_dims, unbiased=False, keepdim=True
    )
    rstd = torch.rsqrt(biased_var + eps)
    out = (a_acc - mean) * rstd
    return out, mean, rstd
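

# Illustrative sketch (hypothetical helper, for documentation only): _normalize
# standardizes `a` along the given dims using the biased variance; the returned
# mean and rstd keep the reduced dims as size 1 so they broadcast against `a`.
def _normalize_example(a: Tensor) -> tuple[Tensor, Tensor, Tensor]:
    # For a 2-D input normalized along dim 1, `out` has the same shape as `a`,
    # while mean and rstd have shape (a.shape[0], 1).
    return _normalize(a, norm_dims=(1,), eps=1e-5)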


# add all specified dimensions
def _unsqueeze_multiple(x: TensorLikeType, dimensions: list[int]) -> TensorLikeType:
    for dim in sorted(dimensions):
        x = torch.unsqueeze(x, dim)
    return x


@register_decomposition(aten.native_group_norm.default)
def native_group_norm(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    batch_size: int,
    num_channels: int,
    flattened_inner_size: int,
    num_groups: int,
    eps: float,
) -> tuple[Tensor, Tensor, Tensor]:
    torch._check(
        input.ndim >= 2,
        lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
    )
    torch._check(
        num_channels % num_groups == 0,
        lambda: "Expected number of channels in input to be divisible by num_groups, "
        + f"but got input of shape {input.shape} and num_groups = {num_groups}",
    )

    # num_channels / num_groups and flattened inner dimension are the reduction axes
    reduction_dims = [2, 3]
    input_reshaped = torch.reshape(
        input,
        [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
    )
    out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
    out = out.view(input.shape)

    broadcast_dims = [0] + list(range(2, input.ndim))
    unsqueeze_bias = None
    if bias is not None:
        unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
    unsqueeze_weight = None
    if weight is not None:
        unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)

    if unsqueeze_weight is not None:
        out = out * unsqueeze_weight
    if unsqueeze_bias is not None:
        out = out + unsqueeze_bias

    out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
    mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
    rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]

    # remove broadcast dimensions from mean and rstd
    mean = torch.squeeze(mean, reduction_dims)
    rstd = torch.squeeze(rstd, reduction_dims)
    return (out, mean, rstd)


@register_decomposition(aten.native_layer_norm)
@out_wrapper("out0", "out1", "out2")
def native_layer_norm(
    input: Tensor,
    normalized_shape: ShapeType,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
) -> tuple[Tensor, Tensor, Tensor]:
    normalized_ndim = len(normalized_shape)
    torch._check(
        normalized_ndim >= 1,
        lambda: "Expected normalized_shape to be at least 1-dimensional, i.e., "
        + "containing at least one element, but got normalized_shape = "
        + str(normalized_shape),
    )
    # torch.Size([1, 2, 3]) == [1, 2, 3] evaluates to False
    # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
    # therefore we use tuple(normalized_shape)
    torch._check(
        weight is None or weight.shape == tuple(normalized_shape),
        lambda: "Expected weight to be of same shape as normalized_shape, but got "
        + "weight of shape "
        + str(weight.shape)  # type: ignore[union-attr]
        + " and normalized_shape = "
        + str(normalized_shape),
    )
    torch._check(
        bias is None or bias.shape == tuple(normalized_shape),
        lambda: "Expected bias to be of same shape as normalized_shape, but got "
        + "bias of shape "
        + str(bias.shape)  # type: ignore[union-attr]
        + " and normalized_shape = "
        + str(normalized_shape),
    )
    torch._check(
        input.ndim >= normalized_ndim
        and input.shape[(input.ndim - normalized_ndim) :] == tuple(normalized_shape),
        lambda: "Given normalized_shape="
        + str(normalized_shape)
        + ", expected input with shape "
        + str(normalized_shape)
        + ", but got input of size "
        + str(input.shape),
    )

    input = input.contiguous()
    if weight is not None:
        weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    axis = input.ndim - normalized_ndim
    reduction_dims = list(range(axis, input.ndim))
    out, mean, rstd = _normalize(input, reduction_dims, eps)

    if weight is None and bias is not None:
        out = out + bias
    elif weight is not None and bias is None:
        out = out * weight
    elif weight is not None and bias is not None:
        out = out * weight + bias

    out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
    if input.device.type in ["cpu", "mtia"]:
        mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]
        rstd = _maybe_convert_to_dtype(rstd, input.dtype)  # type: ignore[assignment]
    return (out, mean, rstd)


@torch._subclasses.fake_impls.register_op_impl(aten.native_layer_norm.default)
def native_layer_norm_fake(fake_mode, func, *args, **kwargs):
    return native_layer_norm(*args)


# TODO: Adding this as a meta function causes functorch tests to fail when compiled with debug mode.
# test/test_eager_transforms.py::TestFunctionalizeCPU::test_functionalize_fx_transpose_simple_cpu
@register_decomposition(aten.permute)
def permute(a: TensorLikeType, *dims) -> TensorLikeType:
    _permutation = utils.canonicalize_dims(
        a.ndim, utils.extract_dims_from_varargs(dims)
    )
    return prims.transpose(a, _permutation)


@register_decomposition(aten.renorm)
@out_wrapper()
def renorm(
    input: TensorLikeType, p: RealNumberType, dim: int, maxnorm: RealNumberType
) -> TensorLikeType:
    torch._check(not isinstance(p, complex), lambda: "renorm: p must be real-valued")
    torch._check(p > 0, lambda: "renorm: non-positive norm not supported")
    torch._check(
        not isinstance(maxnorm, complex), lambda: "renorm: maxnorm must be real-valued"
    )
    torch._check(
        maxnorm >= 0, lambda: f"renorm: expected maxnorm to be >= 0 but got {maxnorm}"
    )
    ndim = input.ndim
    torch._check(
        ndim > 1,
        lambda: f"renorm: input needs at least 2 dimensions, got {ndim} dimensions",
    )

    dim = utils.canonicalize_dim(ndim, dim)
    reduce_dims = list(range(ndim))
    del reduce_dims[dim]

    # For half and bfloat16, compute the norm in float precision, then cast the
    # normalization factor back to the input dtype
    acc_type = utils.get_computation_dtype(input.dtype)
    if acc_type != input.dtype:
        norm = torch.linalg.vector_norm(
            input, p, reduce_dims, keepdim=True, dtype=acc_type
        )
    else:
        norm = torch.linalg.vector_norm(input, p, reduce_dims, keepdim=True)

    eps = 1e-7
    norm_factor = torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0)
    if acc_type != input.dtype:
        norm_factor = prims.convert_element_type(norm_factor, input.dtype)
    return (input * norm_factor).contiguous()


# CompositeImplicitAutograd - don't register decomp
@aten.stft.center.py_impl(DispatchKey.CompositeImplicitAutograd)
def stft(
    input: Tensor,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[Tensor] = None,
    center: bool = True,
    pad_mode: str = "reflect",
    normalized: bool = False,
    onesided: Optional[bool] = None,
    return_complex: Optional[bool] = None,
    align_to_window: Optional[bool] = None,
) -> Tensor:
    torch._check(
        window is None or window.device == input.device,
        lambda: (
            f"stft input and window must be on the same device but got self on {input.device}"
            + f" and window on {window.device}"  # type: ignore[union-attr]
        ),
    )
    torch._check(
        not center or align_to_window is None,
        "stft only supports align_to_window for center = False.",
    )

    hop_length_ = hop_length if hop_length is not None else n_fft // 4
    win_length_ = win_length if win_length is not None else n_fft

    if return_complex is None:
        return_complex_ = input.is_complex() or (
            window is not None and utils.is_complex_dtype(window.dtype)
        )
        torch._check(
            return_complex_,
            lambda: (
                "stft requires the return_complex parameter be given for real inputs, "
                + "and will further require that return_complex=True in a future PyTorch release."
            ),
        )
    else:
        return_complex_ = return_complex

    torch._check(
        utils.is_float_dtype(input.dtype) or utils.is_complex_dtype(input.dtype),
        lambda: "stft expected a tensor of floating point or complex values",
    )
    torch._check(1 <= input.ndim <= 2, lambda: "stft expected a 1D or 2D tensor")

    original_ndim = input.ndim
    if original_ndim == 1:
        input = input.unsqueeze(0)

    if center:
        extra_dims = 3 - input.ndim
        pad_amount = n_fft // 2
        extended_shape = [*itertools.repeat(1, extra_dims), *input.shape]
        input = aten.pad(input.view(extended_shape), [pad_amount, pad_amount], pad_mode)
        input = input.view(input.size()[extra_dims:])

    length = input.size(1)
    torch._check(
        0 < n_fft <= length,
        lambda: f"stft expected 0 < n_fft <= {length}, but got n_fft={n_fft}",
    )
    torch._check(
        hop_length_ > 0,
        lambda: f"stft expected hop_length > 0 but got hop_length={hop_length_}",
    )
    torch._check(
        0 < win_length_ <= n_fft,
        lambda: f"stft expected 0 < win_length <= n_fft but got win_length={win_length_}",
    )
    torch._check(
        window is None or window.shape == (win_length_,),
        lambda: (
            f"expected a 1D window tensor of size equal to win_length={win_length_}, "
            + f"but got window with size {window.shape}"  # type: ignore[union-attr]
        ),
    )

    if win_length_ < n_fft:
        if window is None:
            window = torch.ones(win_length_, dtype=input.dtype, device=input.device)
        left = (n_fft - win_length_) // 2
        window = aten.constant_pad_nd(window, [left, n_fft - win_length_ - left])

    input = input.unfold(dimension=-1, size=n_fft, step=hop_length_)
    if not center and align_to_window:
        input_pad_amount = (n_fft - win_length_) // 2
        input = aten.pad(input, [input_pad_amount, input_pad_amount], pad_mode)
    if window is not None:
        input = input * window

    complex_fft = utils.is_complex_dtype(input.dtype)
    onesided = onesided if onesided is not None else not complex_fft
    norm = "ortho" if normalized else None
    if onesided:
        torch._check(
            not complex_fft,
            lambda: "Cannot have onesided output if window or input is complex",
        )
        out = torch.fft.rfft(input, dim=-1, norm=norm)
    else:
        out = torch.fft.fft(input, dim=-1, norm=norm)

    out.transpose_(1, 2)

    if original_ndim == 1:
        out = out.squeeze_(0)

    return out if return_complex_ else torch.view_as_real(out)


# CompositeImplicitAutograd - don't register decomp
@aten.istft.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def istft(
    input: Tensor,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[Tensor] = None,
    center: bool = True,
    normalized: bool = False,
    onesided: Optional[bool] = None,
    length: Optional[int] = None,
    return_complex=False,
) -> Tensor:
    torch._check(
        window is None or window.device == input.device,
        lambda: (
            f"istft input and window must be on the same device but got self on {input.device}"
            + f" and window on {window.device}"  # type: ignore[union-attr]
        ),
    )

    hop_length_ = hop_length if hop_length is not None else n_fft // 4
    win_length_ = win_length if win_length is not None else n_fft

    torch._check(
        utils.is_complex_dtype(input.dtype),
        lambda: (
            "istft input and window must be on the same device but got self on "
            + f"{input.device} and window on {window.device}"  # type: ignore[union-attr]
        ),
    )
    n_frames = input.size(-1)
    fft_size = input.size(-2)

    expected_output_signal_len = n_fft + hop_length_ * (n_frames - 1)
    torch._check(input.numel() > 0, lambda: "istft input tensor cannot be empty")
    torch._check(
        2 <= input.ndim <= 3,
        lambda: f"istft expected a tensor with 2 or 3 dimensions, but got {input.ndim}",
    )
    onesided_ = onesided if onesided is not None else fft_size != n_fft

    if onesided_:
        torch._check(
            n_fft // 2 + 1 == fft_size,
            lambda: (
                "istft expected the frequency dimension (3rd to the last) of the input tensor "
                + "to match n_fft / 2 + 1 when onesided=True, but got {fft_size}"
            ),
        )
    else:
        torch._check(
            n_fft == fft_size,
            lambda: (
                "istft expected the frequency dimension (3rd to the last) of the input tensor "
                + "to match n_fft when onesided=False, but got {fft_size}",
            ),
        )

    torch._check(
        0 < hop_length_ <= win_length_,
        lambda: "istft expected 0 < hop_length <= win_length",
    )
    torch._check(
        0 < win_length_ <= n_fft, lambda: "istft expected 0 < win_length <= n_fft"
    )
    torch._check(
        window is None or window.shape == (win_length_,),
        lambda: "Invalid window shape. window has to be 1D and length of `win_length`",
    )

    if window is None:
        real_dtype = utils.corresponding_real_dtype(input.dtype)
        window_ = torch.ones(win_length_, dtype=real_dtype, device=input.device)
    else:
        window_ = window

    if win_length_ != n_fft:
        left = (n_fft - win_length_) // 2
        window_ = aten.constant_pad_nd(window_, (left, n_fft - win_length_ - left), 0)

    original_ndim = input.ndim
    if input.ndim == 2:
        input = input.unsqueeze(0)

    input = input.transpose(1, 2)
    norm = "ortho" if normalized else None
    if return_complex:
        torch._check(
            not onesided_,
            lambda: "cannot have onesided output if window or input is complex",
        )
        input = torch.fft.ifft(input, dim=-1, norm=norm)
    else:
        torch._check(
            window is None or not utils.is_complex_dtype(window.dtype),
            lambda: "Complex windows are incompatible with return_complex=False",
        )
        if not onesided_:
            input = input.narrow(dim=-1, start=0, length=n_fft // 2 + 1)
        input = torch.fft.irfft(input, dim=-1, norm=norm)

    assert input.size(2) == n_fft

    y_tmp = input * window_.view([1, 1, n_fft])
    y = aten.unfold_backward(
        y_tmp,
        input_sizes=(y_tmp.size(0), expected_output_signal_len),
        dim=1,
        size=n_fft,
        step=hop_length_,
    )
    window_envelop = aten.unfold_backward(
        window_.pow(2).expand((1, n_frames, n_fft)),
        input_sizes=(y_tmp.size(0), expected_output_signal_len),
        dim=1,
        size=n_fft,
        step=hop_length_,
    )

    assert expected_output_signal_len == y.size(1)
    assert expected_output_signal_len == window_envelop.size(1)

    start = n_fft // 2 if center else 0
    if length is not None:
        end = start + length
    elif center:
        end = expected_output_signal_len - n_fft // 2
    else:
        end = expected_output_signal_len

    length = max(0, end - start)
    y = y.narrow(dim=1, start=start, length=length)
    window_envelop = window_envelop.narrow(dim=1, start=start, length=length)

    y = y / window_envelop
    if original_ndim == 2:
        y = y.squeeze(0)

    if end > expected_output_signal_len:
        warnings.warn(
            "The length of signal is shorter than the length parameter. Result is being "
            + "padded with zeros in the tail. Please check your center and hop_length settings"
        )
        y = aten.constant_pad_nd(y, (0, end - expected_output_signal_len), 0)
    return y


# Get the new shape and stride after applying unfold to an input tensor
def _get_unfold_shape_stride(
    a_shape: ShapeType, a_stride: StrideType, dimension: int, size: int, step: int
):
    a_ndim = len(a_shape)
    dim = utils.canonicalize_dim(a_ndim, dimension, wrap_scalar=True)
    max_size = 1 if a_ndim == 0 else a_shape[dim]
    last_stride = 1 if a_ndim == 0 else a_stride[dim]

    torch._check(
        size <= max_size,
        lambda: f"Maximum size for tensor at dimension {dim} is {max_size} but size is {size}",
    )

    torch._check(
        step > 0,
        lambda: f"Step is {step} but must be > 0",
    )

    shape = list(a_shape)
    strides = list(a_stride)
    shape.append(size)
    strides.append(last_stride)
    if dim < a_ndim:
        shape[dim] = (shape[dim] - size) // step + 1
        strides[dim] *= step
    return shape, strides
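

# Illustrative sketch (hypothetical helper, for documentation only): unfolding a
# contiguous 1-D tensor of length 6 with size=2, step=2 yields three windows of
# two elements each.
def _get_unfold_shape_stride_example():
    # Expected: ([3, 2], [2, 1])
    return _get_unfold_shape_stride((6,), (1,), dimension=0, size=2, step=2)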


@register_decomposition(aten.repeat)
@out_wrapper()
def repeat(a: Tensor, *repeat_shape) -> Tensor:
    repeat_shape = utils.extract_shape_from_varargs(repeat_shape, validate=False)
    torch._check(
        len(repeat_shape) >= len(a.shape),
        lambda: "repeat: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
    )

    if len(repeat_shape) == 0:
        return torch.clone(a)

    num_new_dimensions = len(repeat_shape) - a.ndim
    padded_shape = [1] * num_new_dimensions
    for dim_size in a.shape:
        padded_shape.append(dim_size)

    target_shape = tuple(
        padded_size * repeat_size
        for padded_size, repeat_size in zip(padded_shape, repeat_shape)
    )

    # return an empty tensor if one of the repeat_shape dimensions is zero
    if 0 in repeat_shape:
        return torch.empty(
            target_shape,
            dtype=a.dtype,
            device=a.device,
            requires_grad=a.requires_grad,
            memory_format=utils.suggest_memory_format(a),
        )

    urtensor_shape = target_shape
    urtensor_stride = utils.make_contiguous_strides_for(target_shape)
    for dim, dim_size in enumerate(padded_shape):
        # repeat each dimension by using unfold_copy operation
        urtensor_shape, urtensor_stride = _get_unfold_shape_stride(
            urtensor_shape, urtensor_stride, dim, dim_size, max(dim_size, 1)
        )
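    # Illustrative trace (assumed values, not executed): for a of shape (3,)
    # repeated with repeat_shape=(2,), target_shape is (6,); unfolding that
    # single dimension with size=3, step=3 turns (6,) into shape [2, 3] with
    # strides [3, 1], so expanding `a` and reshaping yields two copies of `a`.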

    # derive permute order by sorting urtensor strides
    enumerated_stride = list(enumerate(urtensor_stride))
    enumerated_stride.sort(key=operator.itemgetter(1), reverse=True)
    permute_order, _sorted_stride = zip(*enumerated_stride)

    # add new and expand dimensions according to urtensor
    repeat_xtensor = a.expand(urtensor_shape)

    # clone tensor to concretize expanded dimensions
    cloned_result = torch.clone(repeat_xtensor)

    # transpose axis so strides are in sorted order
    permuted_result = cloned_result.permute(permute_order)

    # reshape to get contiguous tensor with correct target shape
    return permuted_result.reshape(target_shape)


def _reshape_view_helper(a: TensorLikeType, *shape, allow_copy: bool) -> TensorLikeType:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq

    # Creates a valid shape
    shape = utils.extract_shape_from_varargs(shape, validate=False)
    # Reshape may be given a shape with a -1 length
    # This indicates that the dimension's length should be inferred
    shape = utils.infer_size(shape, a.numel())

    # Special-cases tensors with no elements
    if guard_size_oblivious(a.numel() == 0):
        return as_strided(a, shape, utils.make_contiguous_strides_for(shape))

    # Special-cases reshaping zero dim tensors
    if a.ndim == 0:
        _a = a
        for length in shape:
            assert length == 1
            _a = unsqueeze(_a, -1)
        if _a is a:
            return prims.view_of(a)
        else:
            return _a

    # Special-cases reshaping to zero dim tensors
    if len(shape) == 0:
        _a = a
        for length in a.shape:
            assert length == 1
            _a = squeeze(_a, -1)
        if _a is a:
            return prims.view_of(a)
        else:
            return _a

    if a.is_contiguous():
        # Special-cases for nd_to_1d
        if len(shape) == 1 and a.ndim > 1:
            return torch.as_strided(a, [a.numel()], [1])
        # Special-cases for 1d_to_2d
        if len(shape) == 2 and a.ndim == 1:
            dim0 = shape[0]
            dim1 = shape[1]
            return torch.as_strided(a, [dim0, dim1], [dim1, 1])

    # Handles general case: a 1+D tensor reshaped into a distinct 1+D shape

    # NOTE [Reshape Algorithm]
    # This algorithm works by attempting to greedily construct the desired dimensions in
    # the output shape, left to right. It does this by, conceptually, accumulating
    # dimensions of the original tensor, also left to right, until the dimension
    # can be constructed using prims.split_dim.
    # The algorithm also has special handling for tail squeezes/unsqueezes, like
    # a reshape from (5, 5) to (5, 5, 1) or vice versa.
    #
    # This algorithm does not flatten the original tensor and then split dims as appropriate
    # because that would create copies more often than this algorithm. flatten is the only
    # operation below which can create a view or a copy, and while it prefers creating
    # views it may sometimes create a copy if the tensor's strides do not permit a view.
    # As a result, this algorithm tries to minimize flattening.
    #
    # Note that a better version of this algorithm may exist. Regions which could be
    # flattened without creating a copy can be identified in advance, and that might
    # allow fewer flatten calls or faster short-circuiting to make a copy.
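    #
    # Illustrative trace (assumed values): reshaping a contiguous (6, 4) tensor
    # into (4, 6): the first target length, 4, does not evenly divide the leading
    # size 6, so dims 0 and 1 are flattened to (24,) and then split into (4, 6);
    # the remaining target length 6 then matches the existing dim and is skipped.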
    idx = 0
    a_ = a
    for length in shape:
        # Handles tail unsqueezes
        if idx >= a_.ndim:
            assert length == 1
            last_dim = a_.ndim - 1
            # NOTE: using split_dim instead of unsqueeze may seem silly here,
            # but it's necessary to get the strides correct
            a_ = prims.split_dim(a_, last_dim, a_.shape[last_dim])
            idx = idx + 1
            continue

        # Skips dimensions that are already the correct length
        if guard_size_oblivious(length == a_.shape[idx]):
            idx = idx + 1
            continue

        # Gathers enough original dimensions such that this new dimension can be created
        # Note that this accumulation will terminate because we've verified a and the shape
        # specify the same number of elements above
        accum = a_.shape[idx]
        end = idx
        while guard_size_oblivious(accum % length != 0):
            end = end + 1
            accum = accum * a_.shape[end]
        if end != idx:
            # NOTE: in this case multiple dimensions must be flattened to create the desired dimension
            # This flattening is why reshape sometimes creates a copy -- because flattening
            # may return a view of a copy

            # Checks if collapse can be a view and short-circuits to copying reshape if it can't
            new_shape, _new_strides = prims._collapse_view_helper(a_, idx, end)
            if new_shape is None:
                if allow_copy:
                    return prims.reshape(a, shape)

                msg = f"Cannot view a tensor with shape {a.shape} and strides {a.stride()} as a tensor with shape {shape}!"
                raise ValueError(msg)

            a_ = flatten(a_, idx, end)

        # Splits the (possibly flattened) dimension to create the desired dim length
        if guard_size_oblivious(accum != length):
            a_ = prims.split_dim(a_, idx, length)

        idx = idx + 1

    # Squeezes tail
    while idx < a_.ndim:
        torch._check(
            a_.shape[idx] == 1,
            lambda: f"a.size({idx}) expected to be 1 but got {a_.shape[idx]}",
        )
        a_ = squeeze(a_, idx)

    if a_ is a:
        return prims.view_of(a)
    else:
        return a_


# CompositeImplicitAutograd - don't register decomp
# NOTE: shape is a vararg because Tensor.reshape can be called as
# Tensor.reshape(a, b, c) or Tensor.reshape((a, b, c)); the function call
# torch.reshape doesn't support unpacked shapes
def reshape(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    return _reshape_view_helper(a, *shape, allow_copy=True)


# CompositeImplicitAutograd - don't register decomp
def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    return self.reshape(other.size())


@register_decomposition(aten.roll)
@out_wrapper()
def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
    """Reference implementation of :func:`torch.roll`."""
    dims = utils.canonicalize_dims(a.ndim, dims)
    # ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
    if not isinstance(shifts, Iterable):
        shifts = (shifts,)
    if not isinstance(dims, Iterable):
        dims = (dims,)

    # Avoid modulo by zero
    if a.numel() == 0:
        # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
        return a.clone()

    if a.dim() == 0 and len(dims) > 0:
        raise IndexError(
            f"Dimension specified as {dims[0]} but tensor has no dimensions"
        )

    len_shifts = len(shifts)
    len_dims = len(dims)
    if len_shifts != 1 or len_dims != 1:
        if len_shifts == 0:
            raise RuntimeError("`shifts` required")
        # Takes care of the case when dims is not specified (default)
        # By default, the tensor is flattened before shifting, after which the original shape is restored
        if len_dims == 0 and len_shifts == 1:
            return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
        if len_shifts != len_dims:
            raise RuntimeError(
                f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
            )
        assert len_dims > 1
        tail_shifts = shifts[1:]
        tail_dims = dims[1:]
        first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
        return torch.roll(first_dim_rolled, tail_shifts, tail_dims)

    # This path is taken when only one dimension is rolled
    # For example to get `first_dim_rolled` above
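    # Illustrative example (assumed values): for a 1-D tensor of size 5 rolled
    # by shifts=(2,), start = (5 - 2) % 5 = 3 and the gather indices are
    # [3, 4, 0, 1, 2], so roll([0, 1, 2, 3, 4], 2) -> [3, 4, 0, 1, 2].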
    dim = dims[0]
    size = a.shape[dim]
    start = (size - shifts[0]) % size
    idx = torch.arange(size, device=a.device)
    return a.index_select(dim, torch.fmod(start + idx, size))


@register_decomposition(aten.rot90)
@out_wrapper()
def rot90(
    a: TensorLikeType, k: int = 1, dims: DimsSequenceType = (0, 1)
) -> TensorLikeType:
    """Reference implementation of :func:`torch.rot90`."""
    if len(dims) != 2:
        raise RuntimeError(
            f"expected total rotation dims == 2, but got dims = {len(dims)}"
        )
    if a.ndim < 2:
        raise RuntimeError(f"expected total dims >= 2, but got total dims = {a.ndim}")

    # Do this after the initial checks to be compatible with the behavior in
    # core.
    dims = utils.canonicalize_dims(a.ndim, dims)

    if dims[0] == dims[1]:
        raise RuntimeError(
            f"expected rotation dims to be different, but got dim0 = {dims[0]} and dim1 = {dims[1]}"
        )
    k = k % 4  # Rotation direction is from the second towards the first axis for k < 0
    if k == 1:
        return torch.transpose(torch.flip(a, (dims[1],)), dims[0], dims[1])
    elif k == 2:
        return torch.flip(a, dims)
    elif k == 3:
        return torch.transpose(torch.flip(a, (dims[0],)), dims[0], dims[1])
    else:
        return a.clone(memory_format=torch.contiguous_format)


def _check_stack_inputs(tensors: TensorSequenceType) -> None:
    entry_shape = tensors[0].shape
    for i in range(1, len(tensors)):
        assert tensors[i].shape == entry_shape, (
            f"stack expects each tensor to be equal size, but got {entry_shape} at entry 0 "
            f"and {tensors[i].shape} at entry {i}"
        )


@register_decomposition(aten.stack)
@out_wrapper()
def stack(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
    assert len(tensors) > 0, "stack expects a non-empty TensorList"
    wrapped_dim = utils.canonicalize_dim(tensors[0].ndim + 1, dim)
    # Refs need sparse support to check other condition
    if wrapped_dim < tensors[0].ndim:  # and not tensors[0].is_sparse:
        _check_stack_inputs(tensors)
        result_sizes = list(tensors[0].shape)
        result_sizes.insert(wrapped_dim, len(tensors))
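        # e.g. (illustrative) stacking three (2, 3) tensors at dim=1:
        # result_sizes becomes [2, 3, 3]; cat along dim 1 gives (2, 9),
        # which views back to (2, 3, 3).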
        out = torch.cat(tensors, wrapped_dim)
        return out.view(result_sizes)

    # If dim == tensors[0].ndim, view cannot efficiently handle it
    return torch.cat([t.unsqueeze(wrapped_dim) for t in tensors], dim)


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    result_dtype = dtype or a.dtype
    computation_dtype = utils.get_computation_dtype(result_dtype)
    a_ = _maybe_convert_to_dtype(a, computation_dtype)
    if a.numel() == 0:
        a_exp = exp(a_)
    else:
        a_max = amax(a_, dim, keepdim=True)
        a_exp = exp(a_ - a_max)
    return _maybe_convert_to_dtype(
        true_divide(a_exp, sum(a_exp, dim, keepdim=True)), result_dtype
    )  # type: ignore[return-value]


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def hstack(tensors: TensorSequenceType) -> TensorLikeType:
    torch._check(len(tensors) > 0, lambda: "hstack expects a non-empty TensorList")
    aligned_tensors = atleast_1d(*tensors)
    if aligned_tensors[0].ndim == 1:
        return cat(aligned_tensors, 0)
    return cat(aligned_tensors, 1)


# CompositeImplicitAutograd - don't register decomp
@out_wrapper()
def vstack(tensors: TensorSequenceType) -> TensorLikeType:
    torch._check(len(tensors) > 0, lambda: "vstack expects a non-empty TensorList")
    aligned_tensors = atleast_2d(*tensors)
    return cat(aligned_tensors, 0)


# CompositeImplicitAutograd - don't register decomp
def unflatten(a: TensorLikeType, dim: int, sizes: ShapeType) -> TensorLikeType:
    dim = utils.canonicalize_dim(a.ndim, dim)
    torch._check(len(sizes) != 0, lambda: "unflatten: sizes must be non-empty")
    return a.view(tuple(a.shape[:dim]) + tuple(sizes) + tuple(a.shape[dim + 1 :]))


@register_decomposition(aten.unbind)
def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dim = utils.canonicalize_dim(t.ndim, dim)
    torch._check_index(
        len(t.shape) > 0,
        lambda: "Dimension specified as 0 but tensor has no dimensions",
    )
    if guard_size_oblivious(t.shape[dim] == 0):
        return ()
    else:
        return tuple(
            torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim)
        )


@out_wrapper()
def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    return x.clone(memory_format=torch.contiguous_format).index_copy_(
        dim, index, tensor
    )


def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
    dim = utils.canonicalize_dims(x.ndim, dim)
    torch._check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    # Treat scalars as elements of \R^1
    y = x.unsqueeze(0) if x.ndim == 0 else x
    idx = (slice(None),) * dim + (index,)
    y[idx] = tensor
    return x


@register_decomposition(aten.index_fill)
@out_wrapper()
def index_fill(
    x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
):
    return _index_fill(x, dim, index, value, inplace=False)


@register_decomposition(aten.index_fill_)
def index_fill_(
    x: TensorLike, dim: int, index: TensorLike, value: Union[NumberType, TensorLike]
):
    return _index_fill(x, dim, index, value, inplace=True)


def _index_fill(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    value: Union[NumberType, TensorLike],
    *,
    inplace: bool,
):
    torch._check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    if isinstance(value, TensorLike):
        torch._check(
            value.ndim == 0,
            lambda: "Only supports 0-dimensional value tensor. "  # type: ignore[union-attr]
            f"Got a tensor with {value.ndim} dimensions.",
        )  # type: ignore[arg-type]
    else:
        value = torch.scalar_tensor(
            value,
            dtype=x.dtype,
            layout=x.layout,
            device=x.device,  # type: ignore[arg-type]
        )

    # index_copy has some unnecessary preconditions when x is a scalar. We do this to work through them
    zero_dim = x.ndim == 0
    y = x.unsqueeze(0) if zero_dim else x
    # index_copy does not broadcast on value so we have to do it manually
    shape = list(y.shape)
    shape[dim] = index.numel()
    value = value.expand(shape)
    index_copy = Tensor.index_copy_ if inplace else torch.index_copy
    out = index_copy(y, dim, index, value)  # type: ignore[operator]
    if inplace:
        return x
    else:
        if zero_dim:
            # The clone is necessary so that it returns a fresh tensor rather than a view
            out = out.squeeze(0).clone()
        # index_fill preserves the strides. index_copy always returns contiguous tensors
        if out.stride() != x.stride():
            new_out = torch.empty_like(x)
            new_out.copy_(out)
            out = new_out
        return out


@out_wrapper()
def index_add(
    x: TensorLike,
    dim: int,
    index: TensorLike,
    tensor: TensorLike,
    *,
    alpha: NumberType = 1,
):
    # index_add always returns a new contiguous tensor
    return x.clone(memory_format=torch.contiguous_format).index_add_(
        dim,
        index,
        tensor,
        alpha=alpha,  # type: ignore[arg-type]
    )


@register_decomposition(aten.index_select)
@out_wrapper()
def index_select(x: TensorLike, dim: int, index: TensorLike):
    dim = utils.canonicalize_dims(x.ndim, dim)
    torch._check(
        index.ndim <= 1,
        lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
    )
    if index.ndim == 0:
        index = index.unsqueeze(0)
    if x.ndim == 0:
        # Treat scalars as elements of \R^1
        # We cannot use x[idx] here as it accesses item() (??), hence this awkward construction
        return torch.empty_like(x).index_copy(0, index, x.expand_as(index))

    idx = (slice(None),) * dim + (index,)
    return x[idx]


@register_decomposition(aten.squeeze.dims)
def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if dim is None:
        dims = tuple(idx for idx, size in enumerate(a.shape) if size == 1)
        return prims.squeeze(a, dims) if dims else prims.view_of(a)

    ndim = a.ndim
    dim = utils.canonicalize_dims(ndim, dim)
    dims = (dim,) if isinstance(dim, Dim) else dim
    # Short-circuits if the tensor has no dimensions
    if ndim == 0:
        assert len(dims) == 0 or dims == (0,)
        return prims.view_of(a)

    # Note: squeeze does not modify tensors when the given dim is not a dimension of length 1
    dims = tuple(d for d in dims if guard_size_oblivious(a.shape[d] == 1))
    if len(dims) == 0:
        return prims.view_of(a)
    if len(dims) == 1:
        return prims.squeeze(a, dims)
    dims_list = list(dims)
    dims_list = sorted(dims_list, reverse=True)
    for i in dims_list:
        a = squeeze(a, i)
    return a


@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
    self: Tensor, split_sizes: list[int], dim: int = 0
) -> list[Tensor]:
    # NB: Perform the check_is_size tests first so that the
    # sum test does not try to do a replacement
    for i in range(len(split_sizes)):
        torch._check_is_size(
            split_sizes[i],
            lambda: "split_with_sizes expects split_sizes have only non-negative entries",
        )
    torch._check_with(
        ValueError,
        builtins.sum(split_sizes) == self.shape[dim],
        lambda: f"Split sizes add up to {builtins.sum(split_sizes)} but got the tensor's size of {self.shape[dim]}",
    )

    splits = []
    offset = self.storage_offset()
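    # Illustrative example (assumed values): for a contiguous (5, 2) tensor with
    # strides (2, 1) split along dim 0 with split_sizes=[2, 3], the two views
    # have shapes (2, 2) and (3, 2), at the input's storage offset and at that
    # offset + 4 (= stride[0] * 2) respectively.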

    for split_size in split_sizes:
        new_shape = list(self.shape)
        new_shape[dim] = split_size
        # We reimplement narrow here to avoid a lot of checks in the
        # decomposition of narrow which calls slice_in_dim and slice
        splits.append(self.as_strided(new_shape, self.stride(), offset))
        offset = offset + self.stride()[dim] * split_size
    return splits


# Note: does not work with TensorMetas because of data-dependent control-flow
# CompositeImplicitAutograd - don't register decomp
def tensor_split(
    a: TensorLikeType,
    indices_or_sections: Union[Tensor, DimsType],
    dim: int = 0,
) -> tuple[TensorLikeType, ...]:
    _dim = utils.canonicalize_dim(a.ndim, dim)
    if a.ndim == 0:
        msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
        raise ValueError(msg)

    # If indices_or_sections is a tensor, it must be a CPU Long tensor
    if isinstance(indices_or_sections, TensorLike):
        if not indices_or_sections.device.type == "cpu":
            msg = (
                f"tensor_split: if indices_or_sections is a tensor it must be on the CPU, "
                f"but received one on {indices_or_sections.device}"
            )
            raise ValueError(msg)
        if indices_or_sections.dtype != torch.long:
            msg = (
                "tensor_split: if indices_or_sections is a tensor it must have long dtype, "
                f" but received one with dtype {indices_or_sections.dtype}"
            )
            raise ValueError(msg)

    # Case 0 -- indices_or_sections is an integer or a scalar tensor n and a is split along dim into n parts of equal-ish length
    if isinstance(indices_or_sections, IntLike) or (
        isinstance(indices_or_sections, TensorLike) and indices_or_sections.ndim == 0
    ):
        sections: int = (
            indices_or_sections  # type: ignore[assignment]
            if isinstance(indices_or_sections, Number)
            else indices_or_sections.item()
        )

        if sections <= 0:
            msg = f"tensor_split: number of sections must be greater than 0, but was {sections}"
            raise ValueError(msg)

        dim_size = a.shape[_dim]
        min_split_size = math.floor(dim_size / sections)
        num_splits_one_extra = dim_size % sections
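        # e.g. (illustrative) dim_size=7, sections=3: min_split_size=2 and one
        # split gets an extra element, giving split_sizes [3, 2, 2].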

        split_sizes = []
        for split_idx in range(sections):
            split_size = (
                min_split_size + 1
                if (split_idx < num_splits_one_extra)
                else min_split_size
            )
            split_sizes.append(split_size)

        return tuple(aten.split_with_sizes(a, split_sizes, dim=_dim))
    # Case 1 -- indices_or_sections is a sequence of integers or a 1D tensor describing the splits
    else:
        indices = indices_or_sections
        if isinstance(indices_or_sections, TensorLike):
            if indices_or_sections.ndim != 1:
                msg = (
                    "tensor_split: non-scalar indices_or_sections tensors must have only one dimension, "
                    f"but received a tensor with {indices_or_sections.ndim} dimensions"
                )
                raise ValueError(msg)

            indices = indices_or_sections.tolist()

        indices = [0] + list(indices) + [a.shape[_dim]]
        split_sizes = [indices[i + 1] - indices[i] for i in range(len(indices) - 1)]
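        # e.g. (illustrative) indices (2, 5) with a.shape[_dim] == 7 give
        # boundaries [0, 2, 5, 7] and split_sizes [2, 3, 2].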
        return tuple(aten.split_with_sizes(a, split_sizes, dim=_dim))


# CompositeImplicitAutograd - don't register decomp
def hsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> tuple[TensorLikeType, ...]:
    torch._check(
        a.ndim >= 1,
        lambda: (
            "torch.hsplit requires a tensor with at least 1 dimension, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    dim = 0 if a.ndim == 1 else 1
    if isinstance(indices_or_sections, IntLike):
        split_size = indices_or_sections
        torch._check(
            (split_size != 0 and a.shape[dim] % split_size == 0),
            lambda: (
                "torch.hsplit attempted to split along dimension "
                + str(dim)
                + ", but the size of the dimension "
                + str(a.shape[dim])
                + " is not divisible by the split_size "
                + str(split_size)
                + "!"
            ),
        )
        return tensor_split(a, split_size, dim)

    torch._check_type(
        isinstance(indices_or_sections, (list, tuple)),
        lambda: (
            "hsplit(): received an invalid combination of arguments. "
            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
            f"but got type {type(indices_or_sections)}"
        ),
    )

    split_sizes = indices_or_sections
    return tensor_split(a, split_sizes, dim)


# CompositeImplicitAutograd - don't register decomp
def vsplit(
    a: TensorLikeType, indices_or_sections: DimsType
) -> tuple[TensorLikeType, ...]:
    torch._check(
        a.ndim >= 2,
        lambda: (
            "torch.vsplit requires a tensor with at least 2 dimension, but got a tensor with "
            + str(a.ndim)
            + " dimensions!"
        ),
    )
    if isinstance(indices_or_sections, IntLike):
        split_size = indices_or_sections
        torch._check(
            (split_size != 0 and a.shape[0] % split_size == 0),
            lambda: (
                f"torch.vsplit attempted to split along dimension 0"
                f", but the size of the dimension "
                f"{a.shape[0]}"
                f" is not divisible by the split_size "
                f"{split_size}"
                f"!"
            ),
        )
        return tensor_split(a, split_size, 0)

    torch._check_type(
        isinstance(indices_or_sections, (list, tuple)),
        lambda: (
            "vsplit(): received an invalid combination of arguments. "
            "Expected indices_or_sections to be of type int, list of ints or tuple of ints "
            f"but got type {type(indices_or_sections)}"
        ),
    )

    split_sizes = indices_or_sections
    return tensor_split(a, split_sizes, 0)


@register_decomposition(aten.diag.out)
@out_wrapper()
def diag(
    self: TensorLikeType,
    offset: int = 0,
) -> TensorLikeType:
    ndim = self.dim()
    torch._check(
        ndim in (1, 2), lambda: f"diag(): Supports 1D or 2D tensors. Got {ndim}D"
    )
    if ndim == 1:
        return torch.diag_embed(self, offset)
    else:
        return torch.diagonal_copy(self, offset)


@register_decomposition(aten.diagonal_scatter)
@out_wrapper()
def diagonal_scatter(
    input: TensorLikeType,
    src: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    out = utils.clone_preserve_strides(input)
    diag = out.diagonal(offset, dim1, dim2)
    torch._check(
        diag.shape == src.shape,
        lambda: "expected src to have a size equal to the diagonal of the input."
        f"Got {src.shape} for a diagonal of shape {diag.shape}",
    )
    copy_to(diag, src)
    return out


@register_decomposition(aten.diagonal)
def diagonal(
    self: TensorLikeType,
    offset: int = 0,
    dim1: int = 0,
    dim2: int = 1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diagonal
    """
    num_dims = self.dim()
    dim1 = utils.canonicalize_dim(idx=dim1, rank=num_dims)
    dim2 = utils.canonicalize_dim(idx=dim2, rank=num_dims)

    torch._check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    storage_offset = self.storage_offset()

    if offset >= 0:
        diag_size = max(min(self.size()[dim1], self.size()[dim2] - offset), 0)
    else:
        diag_size = max(min(self.size()[dim1] + offset, self.size()[dim2]), 0)

    if diag_size > 0:
        if offset >= 0:
            storage_offset += offset * self.stride()[dim2]
        else:
            storage_offset -= offset * self.stride()[dim1]

    sizes = [s for i, s in enumerate(self.size()) if i not in (dim1, dim2)]
    sizes.append(diag_size)

    strides = [s for i, s in enumerate(self.stride()) if i not in (dim1, dim2)]
    strides.append(self.stride()[dim1] + self.stride()[dim2])
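    # Illustrative example (assumed values): for a contiguous (4, 5) input with
    # offset=1, the diagonal has size 4, the storage offset is advanced by
    # 1 (= offset * stride[dim2]), and it is read with stride 5 + 1 = 6,
    # i.e. elements a[0, 1], a[1, 2], a[2, 3], a[3, 4].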

    result = self.as_strided(size=sizes, stride=strides, storage_offset=storage_offset)

    return result


@register_decomposition(aten.diag_embed)
@out_wrapper()
def diag_embed(
    t: TensorLikeType,
    offset: int = 0,
    dim1: int = -2,
    dim2: int = -1,
) -> TensorLikeType:
    """
    Reference implementation of torch.diag_embed
    """
    # convert from negative dims
    rank = t.ndim + 1
    dim1 = utils.canonicalize_dim(rank=rank, idx=dim1)
    dim2 = utils.canonicalize_dim(rank=rank, idx=dim2)

    # as per the docs, exchanging dims is equivalent to changing the sign of
    # offset
    if dim1 > dim2:
        dim1, dim2 = dim2, dim1
        offset = -offset

    torch._check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    # as per the docs, the size of last dim is placed at dim1 and dim2
    last_dim = t.size(-1)

    if offset != 0:
        # add padding to match the new size
        t_shape = list(t.shape)
        t_shape[-1] = builtins.abs(offset)
        z = torch.zeros(t_shape, dtype=t.dtype, device=t.device, requires_grad=False)
        pair = (z, t) if offset > 0 else (t, z)
        t = torch.cat(pair, dim=-1)
        # make sure the diagonal always has the same size
        last_dim += builtins.abs(offset)

    # preserve original data, but place 1 at dim1 and move last dim to dim2
    t = t.unsqueeze(dim1).movedim(-1, dim2)

    # generate ranges shifting indices based on offset
    a_range = torch.arange(last_dim, device=t.device, dtype=torch.int64)
    b_range = torch.arange(
        offset, last_dim + offset, device=t.device, dtype=torch.int64
    )

    # broadcast
    cond = a_range == b_range.unsqueeze(-1)
    cond_shape = [last_dim if i in (dim1, dim2) else 1 for i in range(len(t.shape))]
    cond = cond.reshape(cond_shape)
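    # Illustrative example (assumed values): with last_dim=3 and offset=0, cond
    # is a 3x3 boolean mask that is True only on the diagonal spanning
    # (dim1, dim2), so only the diagonal entries keep values from t.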

    # aten.diag_embed always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return utils.mask_tensor(cond, t).contiguous()


@register_decomposition(aten.block_diag)
@out_wrapper()
def _block_diag_iterable(tensors: list[TensorLikeType]) -> TensorLikeType:
    """
    Reference implementation of torch.block_diag
    """
    tensors_2d = [
        tensor.view(1, -1) if tensor.dim() <= 1 else tensor for tensor in tensors
    ]

    ncols = builtins.sum(tensor.shape[1] for tensor in tensors_2d)
    device = tensors_2d[0].device

    result = []

    col_start = 0
    for i, tensor in enumerate(tensors_2d):
        torch._check(
            tensor.dim() == 2,
            lambda: "Input tensors must have 2 or fewer dimensions. "
            f"Input {i} has {tensor.dim()} dimensions",
        )
        torch._check(
            tensor.device == device,
            lambda: "Input tensors must all be on the same device. "
            f"Input 0 is on device {device} and input {i} is on device {tensor.device}.",
        )
        row, col = tensor.shape
        left = torch.zeros((row, col_start), device=device, dtype=tensor.dtype)
        right = torch.zeros(
            (row, ncols - col_start - col), device=device, dtype=tensor.dtype
        )
        result += [torch.cat((left, tensor, right), dim=1)]
        col_start += col

    return torch.cat(result, dim=0)


def block_diag(*tensors: list[TensorLikeType]) -> TensorLikeType:
    """
    This is used as an input to PythonRefInfo. `torch.block_diag`
    expects arguments splatted, but `aten.block_diag` expects only
    one argument that is a list of Tensors.
    """
    return _block_diag_iterable(tensors)  # type: ignore[arg-type]


# CompositeImplicitAutograd - don't register decomp
def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
    if a.ndim < 3:
        raise RuntimeError(
            f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
        )
    if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
        raise RuntimeError(
            "torch.dsplit attempted to split along dimension 2, "
            + f"but the size of the dimension {a.shape[2]} is not divisible by the split_size {sections}!"
        )
    return tensor_split(a, sections, 2)


@register_decomposition(aten.t.default)
def t(a: TensorLikeType):
    # TODO: Add sparse support
    # if a.is_sparse:
    #     sparse_dim = a.sparse_dim()
    #     dense_dim = a.dense_dim()
    #     if not (sparse_dim <= 2 and dense_dim == 0):
    #         raise RuntimeError(
    #             f"t() expects a tensor with <= 2 sparse and 0 dense dimensions, but got {sparse_dim} sparse and"
    #             f"{dense_dim} dense dimensions"
    #         )
    if a.ndim > 2:
        raise RuntimeError(
            f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D"
        )
    return torch.transpose(a, 0, 0 if a.ndim < 2 else 1)


# CompositeImplicitAutograd - don't register decomp
def T(a: TensorLikeType) -> TensorLikeType:
    # n != 2 && n != 0 is deprecated in regular PyTorch.
    torch._check(
        a.ndim in (0, 2),
        lambda: (
            "The use of `x.T` on tensors of dimension other than 0 or 2 "
            "to reverse their shape is not supported."
        ),
    )
    return a.t()


@register_decomposition(aten.alias)
def alias(a: TensorLikeType) -> TensorLikeType:
    return prims.view_of(a)


@register_decomposition(aten.transpose)
def transpose(a: TensorLikeType, dim0: int, dim1: int) -> TensorLikeType:
    _dim0, _dim1 = utils.canonicalize_dims(a.ndim, (dim0, dim1))  # type: ignore[misc]

    if a.ndim <= 1 or dim0 == dim1:
        return aten.alias.default(a)

    _permutation = list(range(0, a.ndim))
    _permutation[_dim0] = _dim1
    _permutation[_dim1] = _dim0
    return torch.permute(a, _permutation)


# Aliases for transpose
swap_axes = transpose


@register_decomposition(aten.unfold)
def unfold(
    self: TensorLikeType, dimension: int, size: int, step: int
) -> TensorLikeType:
    shape, strides = _get_unfold_shape_stride(
        self.shape, self.stride(), dimension, size, step
    )
    return self.as_strided(shape, strides)


@register_decomposition(aten.unfold_copy)
@out_wrapper()
def unfold_copy(self: TensorLikeType, dimension: int, size: int, step: int):
    return self.unfold(dimension, size, step).clone(
        memory_format=torch.contiguous_format
    )


def _cumsumprod_common(
    func,
    init,
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    # We implement all the kwargs of a reduction. ATen just handles dtype
    # nb. This decomposition may not be as efficient as a backend-specific implementation
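    # Illustrative example (assumed values): for a = [1, 2, 3] with func=sum and
    # init=0, the masked matrix below is [[1, 1, 1], [0, 2, 2], [0, 0, 3]], and
    # reducing over the original dim gives the cumulative sum [1, 3, 6].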
    ndim = a.ndim
    dim = utils.canonicalize_dim(ndim, dim)
    if ndim == 0:
        return func(a.unsqueeze(0), dim=0, dtype=dtype, out=out)
    a = a.unsqueeze(dim + 1)
    rg = torch.arange(a.shape[dim], device=a.device)
    mask = rg.unsqueeze(1) <= rg
    for _ in range(ndim - dim - 1):
        mask = mask.unsqueeze(-1)
    masked_a = torch.where(mask, a, init)
    return func(masked_a, dim=dim, dtype=dtype, out=out)


@register_decomposition(aten.cumsum)
def cumsum(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    return _cumsumprod_common(func=sum, init=0, a=a, dim=dim, dtype=dtype, out=out)


@register_decomposition(aten.cumprod)
def cumprod(
    a: TensorLikeType,
    dim: int,
    *,
    dtype: Optional[torch.dtype] = None,
    out: Optional[Tensor] = None,
) -> TensorLikeType:
    return _cumsumprod_common(func=prod, init=1, a=a, dim=dim, dtype=dtype, out=out)


# Note: although squeeze is documented as having the out= kwarg it doesn't
@register_decomposition(aten.unsqueeze)
def unsqueeze(a: TensorLikeType, dim: int) -> TensorLikeType:
    # Note that unsqueeze canonicalizes with rank + 1 because it allows
    # a new innermost dimension to be specified
    ndim = a.ndim + 1
    dim = utils.canonicalize_dim(ndim, dim)
    return prims.expand_dims(a, (dim,), ndim=ndim)


# NOTE: shape is a vararg because Tensor.view can be called as
# Tensor.view(a, b, c) or Tensor.view((a, b, c)); the function call
# torch.view doesn't support unpacked shapes
# TODO: Turn this into a decomposition (currently fails on reshape meta tests)
@register_decomposition(aten.view.default)
def view(a: TensorLikeType, *shape: ShapeType) -> TensorLikeType:
    return _reshape_view_helper(a, *shape, allow_copy=False)


# CompositeImplicitAutograd - don't register decomp
def view_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
    return self.view(other.size())


# CompositeImplicitAutograd - don't register decomp
def ravel(a: TensorLikeType) -> TensorLikeType:
    return reshape(a, (-1,))


# CompositeImplicitAutograd - don't register decomp
# missing ref impl. for aten.gather
@out_wrapper()
def take_along_dim(
    a: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = None
) -> torch.Tensor:
    torch._check(
        a.ndim == indices.ndim,
        lambda: (
            "torch.take_along_dim(): input and indices should have the same "
            f"number of dimensions, but got {a.ndim} dimensions for input, and "
            f"{indices.ndim} dimensions for indices"
        ),
    )

    torch._check(
        utils.is_integer_dtype(indices.dtype),
        lambda: (
            "torch.take_along_dim(): dtype of indices should be int but got "
            f"{indices.dtype} instead"
        ),
    )

    if dim is None:
        return torch.gather(a.view(-1), 0, indices.view(-1))
    else:
        self_sizes = list(a.shape)
        self_sizes[dim] = indices.size(dim)
        broadcast_shape = utils.infer_size_shapes(self_sizes, indices.size())
        indices_broadcast = broadcast_to(indices, broadcast_shape)

        indices_sizes = list(indices.shape)
        indices_sizes[dim] = a.size(dim)
        broadcast_shape = utils.infer_size_shapes(indices_sizes, a.size())
        self_broadcast = broadcast_to(a, broadcast_shape)

        return torch.gather(self_broadcast, dim, indices_broadcast)


@out_wrapper()
def empty(
    *shape,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
    memory_format: torch.memory_format = torch.contiguous_format,
) -> TensorLikeType:
    torch._check(
        memory_format != torch.preserve_format,
        lambda: "torch.empty: the Preserve memory format is not supported",
    )

    shape = utils.extract_shape_from_varargs(shape)

    if memory_format == torch.contiguous_format:
        strides = utils.make_contiguous_strides_for(shape)
    elif memory_format == torch.channels_last_3d:
        strides = utils.make_channels_last_3d_strides_for(shape)
    else:  # memory_format == torch.channels_last
        torch._check(
            memory_format == torch.channels_last,
            lambda: f"torch.empty: received an unknown memory format {memory_format}!",
        )
        strides = utils.make_channels_last_2d_strides_for(shape)

    return torch.empty_strided(
        shape,
        strides,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@out_wrapper()
def empty_permuted(
    shape,
    physical_layout,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLikeType:
    return prims.empty_permuted(
        shape,
        physical_layout,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.new_empty)
@out_wrapper()
def new_empty(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.empty(
        size,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
        layout=layout,
    )


@register_decomposition(aten.new_empty_strided)
@out_wrapper()
def new_empty_strided(
    a: TensorLikeType,
    size: ShapeType,
    stride: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    """
    Reference implementation of torch.Tensor.new_empty_strided
    """

    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.empty_strided(
        size,
        stride,
        dtype=dtype,
        device=device,
        pin_memory=pin_memory,
        layout=layout,
    )


@register_decomposition(aten.zeros.default)
@out_wrapper()
def zeros(
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    size = utils.extract_shape_from_varargs(size)

    if dtype is None:
        dtype = torch.get_default_dtype()

    return torch.full(
        size,
        False if dtype == torch.bool else 0,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.new_zeros)
@out_wrapper()
def new_zeros(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.full(
        size,
        False if (dtype or a.dtype) == torch.bool else 0,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.ones.default)
@out_wrapper()
def ones(
    *size,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    size = utils.extract_shape_from_varargs(size)

    if dtype is None:
        dtype = torch.get_default_dtype()

    return torch.full(
        size,
        True if dtype == torch.bool else 1,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.new_ones)
@out_wrapper()
def new_ones(
    a: TensorLikeType,
    size: ShapeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.full(
        size,
        True if (dtype or a.dtype) == torch.bool else 1,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.new_full)
@out_wrapper()
def new_full(
    a: TensorLikeType,
    size: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    return torch.full(
        size,
        fill_value,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )


@register_decomposition(aten.empty_like)
@out_wrapper()
def empty_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: Optional[torch.layout] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    dtype = a.dtype if dtype is None else dtype
    layout = a.layout if layout is None else layout
    device = a.device if device is None else device

    if memory_format != torch.preserve_format:
        return torch.empty(
            a.shape,
            dtype=dtype,
            layout=layout,
            device=device,
            requires_grad=requires_grad,
            pin_memory=pin_memory,
            memory_format=memory_format,
        )

    # memory_format == torch.preserve_format
    logical_to_physical_perm = (
        utils.compute_elementwise_output_logical_to_physical_perm(a)
    )
    # identity perm is [2, 1, 0]
    return torch.empty_permuted(
        a.shape,
        logical_to_physical_perm,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )


@register_decomposition([aten.arange.start_step, aten.arange.start_out])
@out_wrapper()
def arange(
    start: NumberType = 0,
    end: Optional[NumberType] = None,
    step: NumberType = 1,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)
    device = torch.device(utils.device_or_default(device))

    assert not isinstance(start, complex)
    assert not isinstance(end, complex)
    assert not isinstance(step, complex)

    # Case: torch.arange(5)
    if end is None:
        end = start
        start = 0
    torch._check(step != 0, lambda: "step must be nonzero")
    if step > 0:
        torch._check(
            end >= start,
            lambda: "upper bound and lower bound inconsistent with step sign",
        )
    elif step < 0:
        torch._check(
            end <= start,
            lambda: "upper bound and lower bound inconsistent with step sign",
        )

    def is_finite(x):
        return not isinstance(x, FloatWithoutSymFloat) or math.isfinite(x)

    torch._check(
        is_finite(start) and is_finite(end),
        lambda: f"unsupported range: {start} -> {end}",
    )
    torch._check(
        is_finite(step),
        lambda: f"step must be finite but got {step}",
    )

    args = (start, end, step)
    integer_args = builtins.all(isinstance(arg, IntLike) for arg in args)

    if dtype is None:
        dtype = torch.int64 if integer_args else torch.get_default_dtype()

    is_integer = utils.is_integer_dtype(dtype)
    if is_integer or integer_args:
        xstart = sym_int(start)
        xend = sym_int(end)
        xstep = sym_int(step)

    # For int64 we truncate arguments to int before calculating length, but for
    # other integral dtypes we don't. Odd, but needed to match ATen shapes.
    if dtype == torch.int64 or integer_args:
        # Uses floordiv to avoid ceil in inductor.
        sgn = bool(xstep > 0) - bool(xstep < 0)  # type: ignore[possibly-undefined]
        length = (xend - xstart + xstep - sgn) // xstep  # type: ignore[possibly-undefined]
    else:
        length = math.ceil((end - start) / step)
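    # e.g. (illustrative) with integer arguments start=0, end=5, step=2:
    # sgn=1 and length = (5 - 0 + 2 - 1) // 2 = 3, matching [0, 2, 4].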

    if is_integer:
        return prims.iota(
            length,
            start=xstart,  # type: ignore[possibly-undefined]
            step=xstep,  # type: ignore[possibly-undefined]
            dtype=dtype,
            device=device,
            requires_grad=requires_grad,
        )

    index = prims.iota(
        length,
        start=0,
        step=1,
        dtype=torch.int64,
        device=device,
        requires_grad=False,
    )

    computation_dtype = (
        torch.long if integer_args else utils.get_acc_type(dtype, device)
    )
    index = _maybe_convert_to_dtype(index, computation_dtype)
    result = start + step * index
    result = _maybe_convert_to_dtype(result, dtype)

    if requires_grad:
        result.requires_grad_(True)
    return result


@register_decomposition(aten.lerp)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("start", "end", "weight"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def lerp(start: Tensor, end: Tensor, weight: Union[Tensor, NumberType]):
    inputs = [start, end]
    if isinstance(weight, Number):
        weight = start.new_full((), weight)  # type: ignore[arg-type]
    else:
        inputs.append(weight)
    assert isinstance(weight, Tensor)  # mypy
    # We implement it this way for numerical stability. We assume (in the stability
    # optimisation) that 0 <= weight <= 1. We take the abs to deal with complex numbers.
    # We want to perform operations near zero, which is where floating point is most
    # precise; thus, we perform the following optimisation:
    # If weight.abs() >= 0.5:
    #    return (1 - weight) * (start - end) + end
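    # e.g. (illustrative) weight = 0.9: coeff = -0.1 and base = end, so the
    # result is end - 0.1 * (end - start) = 0.1 * start + 0.9 * end.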
    mask = weight.abs() >= 0.5
    coeff = torch.where(mask, weight - 1, weight)
    base = torch.where(mask, end, start)
    output = coeff * (end - start) + base
    # make sure the decomposition output's stride is same as non-decomposition path.
    stride = utils.compute_elementwise_output_strides(*_maybe_broadcast(*inputs))
    if output.stride() != stride:
        output = prims.copy_strided(output, stride)

    return handle_noncontiguous_outputs(inputs, output)


@register_decomposition(aten.linspace)
@out_wrapper()
def linspace(
    start: Union[NumberType, TensorLikeType],
    end: Union[NumberType, TensorLikeType],
    steps: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    if isinstance(start, TensorLikeType):
        torch._check(
            start.dim() == 0,
            lambda: "linspace only supports 0-dimensional start and end tensors",
        )
        start = _maybe_convert_to_dtype(start, torch.float64)
    if isinstance(end, TensorLikeType):
        torch._check(
            end.dim() == 0,
            lambda: "linspace only supports 0-dimensional start and end tensors",
        )
        end = _maybe_convert_to_dtype(end, torch.float64)

    if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
        default_complex_dtype = utils.corresponding_complex_dtype(
            torch.get_default_dtype()
        )
        if dtype is None:
            dtype = default_complex_dtype
        else:
            torch._check(
                utils.is_complex_dtype(dtype),
                lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
            )
    else:
        dtype = dtype or torch.get_default_dtype()
    assert isinstance(dtype, torch.dtype)

    # steps does not participate in the computation of the dtype
    torch._check_type(
        isinstance(steps, IntLike),
        lambda: f"received an invalid combination of arguments - got \
({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
    )
    assert isinstance(steps, IntLike)  # for mypy
    torch._check(steps >= 0, lambda: "number of steps must be non-negative")

    factory_kwargs = {
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "requires_grad": requires_grad,
    }
    if steps == 0:
        return torch.full((0,), 0, dtype=dtype, **factory_kwargs)  # type: ignore[arg-type]
    if steps == 1:
        if isinstance(start, TensorLikeType):
            empty_tensor = torch.empty((steps,), dtype=dtype, **factory_kwargs)  # type: ignore[arg-type]
            return torch.ops.aten.copy.default(empty_tensor, start)
        else:
            return torch.full((steps,), start, dtype=dtype, **factory_kwargs)  # type: ignore[arg-type]

    # Perform the arange in int because some backends like ATen or Triton do not support all the dtypes
    rg = torch.arange(0, steps, **factory_kwargs)  # type: ignore[arg-type]

    # Small types need to be computed in higher precision as this is, at heart, an associative scan
    dtype_red = (
        torch.int64
        if (utils.is_boolean_dtype(dtype) or utils.is_integer_dtype(dtype))
        else dtype
    )
    computation_dtype, _ = utils.reduction_dtypes(
        rg, REDUCTION_OUTPUT_TYPE_KIND.SAME, dtype_red
    )
    cast_rg = partial(_maybe_convert_to_dtype, dtype=computation_dtype)

    # Conceptually this computes torch.lerp(start, end, rg / (steps - 1)) without
    # performing rg / (steps - 1) explicitly.
    # With this we get out[0] == start and out[-1] == end
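    # e.g. (illustrative) start=0, end=10, steps=5: step=2.5; indices 0, 1, 2 are
    # evaluated from the start (0, 2.5, 5) and indices 3, 4 from the end
    # (7.5, 10), so out == [0, 2.5, 5, 7.5, 10].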
    step = (end - start) / (steps - 1)
    out = torch.where(
        rg < steps / 2,
        start + step * cast_rg(rg),  # type: ignore[arg-type,operator]
        end - step * cast_rg((steps - 1) - rg),  # type: ignore[arg-type,operator]
    )
    return _maybe_convert_to_dtype(out, dtype)  # type: ignore[return-value]


@register_decomposition(aten.logspace)
@out_wrapper()
def logspace(
    start: Union[NumberType, TensorLikeType],
    end: Union[NumberType, TensorLikeType],
    steps: NumberType,
    base: NumberType = 10,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: torch.layout = torch.strided,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    if dtype is None:
        dtype = torch.get_default_dtype()

    # NB: NumPy doesn't have this cast
    if prims.utils.is_integer_dtype(dtype):
        if isinstance(start, FloatLike):
            start = sym_int(start)
        elif isinstance(start, TensorLikeType):
            torch._check(
                start.dim() == 0,
                lambda: "logspace only supports 0-dimensional start and end tensors",
            )
            start = _maybe_convert_to_dtype(start, dtype)
        if isinstance(end, FloatLike):
            end = sym_int(end)
        elif isinstance(end, TensorLikeType):
            torch._check(
                end.dim() == 0,
                lambda: "logspace only supports 0-dimensional start and end tensors",
            )
            end = _maybe_convert_to_dtype(end, dtype)

    if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)):
        default_complex_dtype = utils.corresponding_complex_dtype(
            torch.get_default_dtype()
        )
        dtype = default_complex_dtype
        _dtype = None  # torch.linspace will infer the correct dtype
    else:
        _dtype = torch.float64

    assert not isinstance(base, complex)  # for mypy
    if base < 0:
        raise NotImplementedError
    ret = torch.linspace(  # type: ignore[misc]
        start,  # type: ignore[arg-type]
        end,  # type: ignore[arg-type]
        steps,  # type: ignore[arg-type]
        dtype=_dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
    return _maybe_convert_to_dtype(torch.pow(base, ret), dtype)  # type: ignore[arg-type,return-value]


@overload
def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
    pass


@overload
def meshgrid(*tensors: TensorLikeType, indexing: str):
    pass


@register_decomposition(aten.meshgrid)  # type: ignore[misc]
def meshgrid(
    *tensors: Union[TensorLikeType, list[TensorLikeType], tuple[TensorLikeType]],
    indexing: str,
) -> list[TensorLikeType]:
    # This ref simultaneously handles two overloads (see stubs above)
    # The `indexing` argument is currently optional for torch.meshgrid, but we
    # plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
    if isinstance(tensors[0], (list, tuple)):
        assert len(tensors) == 1
        tensors = tuple(tensors[0])

    torch._check(
        builtins.all(isinstance(a, TensorLike) for a in tensors),
        lambda: "meshgrid expects its inputs to be tensors",
    )

    torch._check(len(tensors) > 0, lambda: "meshgrid expects a non-empty TensorList")

    for i in range(len(tensors) - 1):
        torch._check(
            tensors[i].dtype == tensors[i + 1].dtype,  # type: ignore[union-attr]
            lambda: "meshgrid expects all tensors to have the same dtype",
        )
        torch._check(
            tensors[i].device == tensors[i + 1].device,  # type: ignore[union-attr]
            lambda: "meshgrid expects all tensors to have the same device",
        )

    swap_first_and_second_tensors = False
    if indexing == "xy":
        swap_first_and_second_tensors = len(tensors) >= 2
        if swap_first_and_second_tensors:
            tensors = (tensors[1], tensors[0], *tensors[2:])
    else:
        torch._check(
            indexing == "ij",
            lambda: (
                'torch.meshgrid: indexing must be one of "xy" or "ij", '
                f"but received: {indexing}"
            ),
        )

    result_shape: list[int] = []
    for t in tensors:
        assert isinstance(t, TensorLike)  # mypy
        torch._check(
            t.ndim == 0 or t.ndim == 1,
            lambda: f"torch.meshgrid: Expected 0D or 1D tensor in the tensor list but got: {t}",
        )
        result_shape.append(t.numel())

    grids: list[TensorLikeType] = []
    for i, t in enumerate(tensors):
        assert isinstance(t, TensorLike)  # mypy
        if t.ndim == 0:
            t = t.view((1,))
        grids.append(prims.broadcast_in_dim(t, result_shape, (i,)))

    if swap_first_and_second_tensors:
        # Swap outputs if we originally swapped at the beginning
        grids[0], grids[1] = grids[1], grids[0]

    return grids


# CompositeImplicitAutograd - don't register decomp
def movedim(
    input: TensorLikeType,
    source: Union[int, DimsSequenceType],
    destination: Union[int, DimsSequenceType],
) -> TensorLikeType:
    """
    Reference implementation of torch.movedim
    """
    if type(source) is int:
        source = (source,)
    if type(destination) is int:
        destination = (destination,)

    # Converts to list to produce an error message compatible with core PyTorch,
    # which prints sequences in square brackets.
    torch._check(
        len(source) == len(destination),  # type: ignore[arg-type]
        lambda: (
            "movedim: Invalid source or destination dims: source "  # type: ignore[arg-type]
            f"({list(source)} dims) should contain the same number "  # type: ignore[arg-type]
            f"of dims as destination ({list(destination)} dims)"  # type: ignore[arg-type]
        ),
    )

    rank = input.ndim
    ss = tuple(utils.canonicalize_dims(rank=rank, indices=source))  # type: ignore[arg-type]
    ds = tuple(utils.canonicalize_dims(rank=rank, indices=destination))  # type: ignore[arg-type]

    sss = set(ss)
    dss = set(ds)

    # See above on why this converts to list in error messages.
    torch._check(
        len(ss) == len(sss),
        lambda: f"movedim: repeated dim in `source` ({list(source)})",  # type: ignore[arg-type]
    )
    torch._check(
        len(ds) == len(dss),
        lambda: f"movedim: repeated dim in `destination` ({list(destination)})",  # type: ignore[arg-type]
    )

    m = dict(zip(ds, ss))
    dims = []
    si = 0  # source index
    for di in range(rank):
        # check if the destination index is in the mapping
        s = m.get(di)
        if s is not None:
            # insert source index if found
            dims.append(s)
        else:
            # insert source index sequentially, skipping indices from the mapping
            while si in sss:
                si += 1
            dims.append(si)
            si += 1
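
    # For example (illustrative): with rank == 3, source == (0,) and destination == (2,),
    # m == {2: 0} and dims becomes [1, 2, 0], so the result is input.permute(1, 2, 0).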

    result = torch.permute(input, tuple(dims))

    return result


# NOTE: for convenience, shape can be a tuple of ints or a tuple containing a tuple of ints
@register_decomposition(aten.empty_strided)
@out_wrapper()
def empty_strided(
    shape: Union[ShapeType, tuple[ShapeType]],
    strides: StrideType,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLikeType:
    # Checks that layout is strided and pin_memory is False
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    shape = utils.extract_shape_from_varargs(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device("cpu") if device is None else device

    return prims.empty_strided(
        shape,
        strides,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )


@register_decomposition(aten.eye)
@out_wrapper()
def eye(
    n: int,
    m: Optional[int] = None,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,  # TODO: unused
) -> TensorLikeType:
    """
    Reference implementation of torch.eye
    """
    if m is None:
        m = n

    torch._check(n >= 0, lambda: f"n must be greater or equal to 0, got {n}")
    torch._check(m >= 0, lambda: f"m must be greater or equal to 0, got {m}")

    range_n = torch.arange(n, dtype=torch.int64, device=device, requires_grad=False)
    range_m = torch.arange(m, dtype=torch.int64, device=device, requires_grad=False)

    cond = range_n.unsqueeze(-1) == range_m
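    # cond has shape (n, m) and is True exactly on the main diagonal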
    if dtype is torch.bool:
        return cond
    else:
        one = torch.ones(
            (1,),
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
            requires_grad=False,
        )
        return torch.where(cond, one, 0)
    # TODO: Use requires_grad.  All refs taking the requires_grad kwarg must
    # return a leaf tensor.
    # result.requires_grad_(requires_grad)


@register_decomposition([aten.full.default, aten.full.out])
@out_wrapper()
def full(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
) -> TensorLikeType:
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)

    dtype = dtype if dtype is not None else utils.type_to_dtype(type(fill_value))
    device = device if device is not None else torch.device("cpu")

    e = empty(
        shape,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
    )
    return torch.fill(e, fill_value)  # type: ignore[arg-type]


def full_like(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    e = torch.empty_like(
        a,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )
    return fill(e, fill_value)


@register_decomposition(aten.zeros_like)
@out_wrapper()
def zeros_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    return torch.full_like(
        a,
        False if (dtype or a.dtype) == torch.bool else 0,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )


@register_decomposition(aten.ones_like)
@out_wrapper()
def ones_like(
    a: TensorLikeType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
) -> TensorLikeType:
    return torch.full_like(
        a,
        True if (dtype or a.dtype) == torch.bool else 1,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        requires_grad=requires_grad,
        memory_format=memory_format,
    )


@register_decomposition(aten.randn.default)
@out_wrapper()
def randn(
    *shape,
    dtype: Optional[torch.dtype] = None,
    device: Optional[DeviceLikeType] = None,
    layout: Optional[torch.layout] = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLikeType:
    utils.check_pin_memory(pin_memory)

    shape_ = utils.extract_shape_from_varargs(shape)

    dtype = utils.dtype_or_default(dtype)
    device = utils.device_or_default(device)

    return prims.normal(
        shape_,
        mean=0.0,
        std=1.0,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )


def scalar_tensor(
    a: NumberType,
    *,
    dtype: Optional[torch.dtype] = None,
    layout: torch.layout = torch.strided,
    device: Optional[DeviceLikeType] = None,
    pin_memory: bool = False,
) -> TensorLikeType:
    utils.check_layout(layout)
    utils.check_pin_memory(pin_memory)
    dtype = dtype if dtype is not None else utils.type_to_dtype(type(a))
    device = device if device is not None else torch.device("cpu")
    return prims.scalar_tensor(a, dtype=dtype, device=device)


#
# Randomness References
#


def _uniform_helper(
    shape: ShapeType,
    low: Union[bool, int, float] = 0.0,
    high: Union[bool, int, float] = 1.0,
    *,
    dtype: torch.dtype,
    device: DeviceLikeType,
) -> TensorLikeType:
    utils.validate_shape(shape)

    assert isinstance(low, Number)
    assert isinstance(high, Number)
    low = sym_float(low)
    high = sym_float(high)

    assert isinstance(dtype, torch.dtype)
    device = utils.canonicalize_device(device)

    return prims._uniform_helper(shape, low=low, high=high, dtype=dtype, device=device)


@register_decomposition(aten.masked_fill)
@out_wrapper()
def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType):
    python_type = utils.dtype_to_type(a.dtype)
    if isinstance(value, Number):
        value_type = type(value)
    else:
        # NOTE: Could not use value = item(value) as it resulted in
        # RuntimeError: Cannot cast FakeTensor(cpu) to number
        value_ndim = value.ndim
        torch._check(
            value_ndim == 0,
            lambda: f"only supports a 0-dimensional value tensor, but got tensor with {value_ndim} dimension",
        )
        # `masked_fill` allows a CPU scalar `value` tensor to be used with a tensor on
        # cuda, xpu, hpu, or the privateuse1 backend, but not the other way around.
        is_cpu_scalar = (
            a.device.type
            in ["cuda", "xpu", torch._C._get_privateuse1_backend_name(), "hpu"]
            and value.device.type == "cpu"
        )
        torch._check(
            is_cpu_scalar or value.device == a.device,
            lambda: "Expected `value` to be on same device as `a`",
        )
        value_type = utils.dtype_to_type(value.dtype)

    if value_type is complex:
        # Only downcasting from complex to a lower type is disallowed.
        # Casting `value` to a lower type is allowed in other cases,
        # e.g. float -> int.
        # Ref: https://github.com/pytorch/pytorch/issues/79195
        torch._check(
            utils.is_weakly_lesser_type(value_type, python_type),
            lambda: f"could not convert to type {python_type} without overflow",
        )

    # Since `where` allows type-promotion,
    # cast value to correct type before passing to `where`
    value = _maybe_convert_to_dtype(value, a.dtype)
    r = torch.where(mask, value, a)  # type: ignore[arg-type]

    # aten.masked_fill always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return r.contiguous()


@register_decomposition(aten.masked_fill_)
def masked_fill_(
    a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLikeType
) -> TensorLikeType:
    b = torch.masked_fill(a, mask, value)  # type: ignore[arg-type]
    a.copy_(b)
    return a


# CompositeImplicitAutograd - don't register decomp
def allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    _check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    return bool(
        torch.all(torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).item()
    )


def equal(a: TensorLikeType, b: TensorLikeType) -> bool:
    utils.check_same_device(a, b, allow_cpu_scalar_tensors=False)
    utils.check_same_dtype(a, b)

    # Shape check
    if a.ndim != b.ndim:
        return False

    for x, y in zip(a.shape, b.shape):
        if x != y:
            return False

    # Short-circuits if there are no elements to validate
    if a.numel() == 0:
        return True

    return item(all(eq(a, b)))  # type: ignore[return-value]


@register_decomposition(aten.norm)
@out_wrapper(exact_dtype=True)
def norm(
    input: TensorLikeType,
    p: Optional[Union[float, str]] = "fro",
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # In these cases we compute the "Frobenius norm"
    if (
        p == "fro" and (dim is None or isinstance(dim, Dim) or len(dim) <= 2)
    ) or p is None:
        p = 2
    if isinstance(dim, Dim):
        dim = [dim]
    if isinstance(p, str):
        # Here we either compute the nuclear norm, or we call matrix_norm with arguments
        # that will throw an error
        if dim is None:
            dim = tuple(range(input.ndim))
        return torch.linalg.matrix_norm(input, p, dim, keepdim, dtype=dtype)
    else:
        return torch.linalg.vector_norm(input, p, dim, keepdim, dtype=dtype)


@register_decomposition(aten.trace)
@out_wrapper()
def trace(self: TensorLikeType) -> TensorLikeType:
    torch._check(
        self.ndim == 2, lambda: f"expected a matrix, but got tensor with dim {self.ndim}"
    )
    return torch.sum(torch.diag(self, 0))


def _make_r_binary_op(base_op):
    def rop(
        a: Union[TensorLikeType, NumberType],
        b: Union[TensorLikeType, NumberType],
    ) -> TensorLikeType:
        return base_op(b, a)

    return rop


rtruediv = _make_r_binary_op(true_divide)
rfloordiv = _make_r_binary_op(floor_divide)
rpow = _make_r_binary_op(pow)


@register_decomposition(aten.triu)
@out_wrapper()
def triu(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    torch._check(
        a.ndim >= 2, lambda: "triu: input tensor must have at least 2 dimensions"
    )
    h, w = a.shape[-2:]
    mask = (
        torch.arange(w, device=a.device).unsqueeze(-2)
        - torch.arange(h, device=a.device).unsqueeze(-1)
    ) >= diagonal
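    # mask[i, j] is True when j - i >= diagonal, i.e. for entries on or above the
    # `diagonal`-th diagonal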

    # aten.triu always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return utils.mask_tensor(mask, a).contiguous()


@register_decomposition(aten.tril)
@out_wrapper()
def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
    torch._check(
        a.ndim >= 2, lambda: "tril: input tensor must have at least 2 dimensions"
    )
    h, w = a.shape[-2:]
    mask = (
        torch.arange(w, device=a.device).unsqueeze(-2)
        - torch.arange(h, device=a.device).unsqueeze(-1)
    ) <= diagonal

    # aten.tril always returns a new contiguous tensor
    # contiguous() is needed to correctly model the output stride
    return utils.mask_tensor(mask, a).contiguous()


# This is based on get_tril_size in aten/src/ATen/native/TensorFactories.h
# The components of the matrix that belong to the lower triangle with offset
# form a pentagon that can be broken down into a top trapezoid and a bottom
# rectangle. For the implementation of tril_indices, we need the sizes of
# both of these, as well as the length of the top side of the trapezoid.
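# For example (illustrative): for a 4 x 3 matrix with offset 0 the lower triangle has
# 1 + 2 + 3 + 3 == 9 elements, split into a 6-element top trapezoid (rows 0-2) and a
# 3-element bottom rectangle (row 3), with a first row of length 1, so this function
# returns (6, 3, 1).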
def _get_tril_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
    if row == 0 or col == 0:
        return 0, 0, 0

    m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
    m_last_row = max(0, min(col, row + offset))
    n_row_all = max(0, min(row, row + offset))
    n_row_trapezoid = m_last_row - m_first_row + 1

    # Number of elements in top trapezoid
    trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
    # Number of elements in bottom rectangle
    diff_row = n_row_all - n_row_trapezoid
    rectangle_size = max(0, diff_row * col)

    return trapezoid_size, rectangle_size, m_first_row


def _trilu_checks(
    name: str,
    row: int,
    col: int,
    dtype: torch.dtype,
    layout: torch.layout,
    pin_memory: bool,
):
    torch._check(row >= 0, lambda: f"row must be non-negative, got {row}")
    torch._check(col >= 0, lambda: f"col must be non-negative, got {col}")
    torch._check(
        dtype in (torch.int32, torch.int64),
        lambda: f"\"{name}\" not implemented for '{dtype}'",
    )


# This is based on tril_indices_cuda in aten/src/ATen/native/cuda/TensorFactories.cu
@register_decomposition(aten.tril_indices)
@out_wrapper()
def tril_indices(
    row: int,
    col: int,
    offset: int = 0,
    *,
    dtype: torch.dtype = torch.long,
    layout: torch.layout = torch.strided,
    device: DeviceLikeType = "cpu",
    pin_memory: bool = False,
) -> TensorLikeType:
    _trilu_checks("tril_indices", row, col, dtype, layout, pin_memory)

    trapezoid_size, rectangle_size, m_first_row = _get_tril_sizes(row, col, offset)
    row_offset = max(0, -offset)

    arange_kw = partial(
        torch.arange, layout=layout, device=device, pin_memory=pin_memory
    )

    # first we do the indices for top trapezoid
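    # Row r of the trapezoid is preceded by r * m_first_row + r * (r - 1) / 2 elements,
    # so given a flat index x we recover the row by solving the quadratic
    # r**2 / 2 + b * r == x with b = m_first_row - 0.5, i.e. r = floor(-b + sqrt(b*b + 2*x)).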
    xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
    b = m_first_row - 0.5
    row_inds1 = torch.floor(-b + torch.sqrt(b * b + 2 * xs1))
    col_inds1 = torch.floor(xs1 - (2 * m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
    row_inds1 = _maybe_convert_to_dtype(row_inds1 + row_offset, dtype)
    col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)

    # then bottom rectangle
    xs2 = arange_kw(0, rectangle_size, dtype=dtype)
    row_inds2 = xs2 // col + (col - m_first_row + 1 + row_offset)
    col_inds2 = xs2 % col

    return torch.stack(
        (torch.cat((row_inds1, row_inds2)), torch.cat((col_inds1, col_inds2)))
    )


# Similar to _get_tril_sizes above, but here there is a top trapezoid and
# a bottom rectangle instead. Note that you can't reduce this to
# _get_tril_sizes(col, row, -offset) because that would correspond to
# decomposing into a left trapezoid and right rectangle.
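# For example (illustrative): for a 4 x 3 matrix with offset 0 the upper triangle has
# 3 + 2 + 1 == 6 elements, all of them in the trapezoid (the top rectangle is empty
# because offset >= 0), so this function returns (6, 0, 3).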
def _get_triu_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
    if row == 0 or col == 0:
        return 0, 0, 0

    m_first_row = max(0, col - offset) if offset > 0 else col

    # Number of elements in top rectangle
    rectangle_size = max(0, min(row, -offset) * col)

    # Number of elements in bottom trapezoid
    trapezoid_size_tril, rectangle_size_tril, _ = _get_tril_sizes(row, col, offset - 1)
    triu_size = row * col - (trapezoid_size_tril + rectangle_size_tril)
    trapezoid_size = triu_size - rectangle_size

    return trapezoid_size, rectangle_size, m_first_row


@register_decomposition(aten.triu_indices)
@out_wrapper()
def triu_indices(
    row: int,
    col: int,
    offset: int = 0,
    *,
    dtype: torch.dtype = torch.long,
    layout: torch.layout = torch.strided,
    device: DeviceLikeType = "cpu",
    pin_memory: bool = False,
) -> TensorLikeType:
    _trilu_checks("triu_indices", row, col, dtype, layout, pin_memory)

    trapezoid_size, rectangle_size, m_first_row = _get_triu_sizes(row, col, offset)
    col_offset = max(0, offset)

    arange_kw = partial(
        torch.arange, layout=layout, device=device, pin_memory=pin_memory
    )

    # indices for top rectangle
    xs2 = arange_kw(0, rectangle_size, dtype=dtype)
    row_inds2 = xs2 // col
    col_inds2 = xs2 % col

    # bottom trapezoid
    xs1 = arange_kw(0, trapezoid_size, dtype=torch.float64)
    b = -0.5 - m_first_row
    row_inds1 = torch.floor(-b - torch.sqrt(b * b - 2 * xs1))
    col_inds1 = torch.floor(xs1 - ((2 * m_first_row - 1 - row_inds1) * row_inds1) * 0.5)
    row_inds1 = _maybe_convert_to_dtype(row_inds1, dtype)
    col_inds1 = _maybe_convert_to_dtype(col_inds1, dtype)

    if col:
        row_inds1 = row_inds1 + (rectangle_size // col)
    col_inds1 = col_inds1 + col_offset

    return torch.stack(
        (torch.cat((row_inds2, row_inds1)), torch.cat((col_inds2, col_inds1)))
    )


@register_decomposition(aten.bucketize)
@out_wrapper(exact_dtype=True)
def bucketize(
    a: TensorOrNumberLikeType,
    boundaries: TensorLikeType,
    *,
    out_int32: bool = False,
    right: bool = False,
):
    torch._check(
        boundaries.dim() == 1,
        lambda: f"boundaries tensor must be 1 dimension but got dim({boundaries.dim()})",
    )

    a = a if isinstance(a, torch.Tensor) else torch.tensor(a)
    out_dtype = torch.int32 if out_int32 else torch.int64
    n_boundaries = boundaries.shape[-1]
    if n_boundaries == 0:
        return torch.zeros_like(a, dtype=out_dtype)
    # We are trying to find the bucket (defined by pairs of consecutive elements of `boundaries`)
    # that each element of `a` belongs to. We use binary search to achieve logarithmic complexity,
    # but each step of the search is done "in parallel" over all elements of `a`.
    # We can't use int32 for indexing, so we do all computations in int64 and convert at the end.
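    # For example (illustrative, doctest-style):
    #   >>> torch.bucketize(torch.tensor([1, 5, 9]), torch.tensor([2, 4, 6, 8]))
    #   tensor([0, 2, 4])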
    start = torch.zeros(a.shape, device=a.device, dtype=torch.int64)
    end = start + n_boundaries
    # Max depth of the binary search.
    # Since we can't break out of the loop at different points for different elements of a,
    # we just run the maximum number of iterations that binary search requires and use a
    # condition tensor (cond_update below) to stop updating once the search terminates.

    # For the first iteration through the loop we can skip some checks, so it is handled separately.
    mid = start + (end - start) // 2
    mid_val = boundaries[mid]
    if right:
        cond_mid = mid_val > a
    else:
        cond_mid = mid_val >= a
    start = torch.where(cond_mid, start, mid + 1)

    if n_boundaries > 1:
        cond_update = torch.ones_like(a, dtype=torch.bool)
        niters = int(math.log2(n_boundaries))
        for _ in range(niters):
            end = torch.where(cond_mid & cond_update, mid, end)
            cond_update = start < end
            # start might end up pointing to 1 past the end, we guard against that
            mid = torch.where(cond_update, start + (end - start) // 2, 0)
            mid_val = boundaries[mid]
            # If right is true, the buckets are closed on the *left*
            # (i.e., we are doing the equivalent of std::upper_bound in C++)
            # Otherwise they are closed on the right (std::lower_bound)
            if right:
                cond_mid = mid_val > a
            else:
                cond_mid = mid_val >= a
            start = torch.where((~cond_mid) & cond_update, mid + 1, start)

    return start.to(dtype=out_dtype)


@register_decomposition(aten.cauchy)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def cauchy(self, median=0, sigma=1, generator=None):
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"Cauchy distribution is a continuous probability distribution. \
        dtype must be a floating point but you specified {self.dtype}",
    )
    torch._check(
        sigma > 0.0,
        lambda: f"cauchy_ expects sigma > 0.0, but found sigma={sigma}",
    )
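    # Inverse transform sampling: the Cauchy quantile function is
    # median + sigma * tan(pi * (U - 0.5)) for U ~ Uniform(0, 1).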
    return median + sigma * torch.tan(math.pi * (torch.rand_like(self) - 0.5))


@register_decomposition(aten.exponential)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def exponential(self, rate=1, generator=None):
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"Exponential distribution is a continuous probability distribution. \
        dtype must be a floating point but you specified {self.dtype}",
    )
    torch._check(
        rate > 0.0,
        lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}",
    )

    uniform_val = torch.rand_like(self)

    # Copying the numerics of transformation::exponential (see the comment there):
    # curand_uniform has (0, 1] bounds. log(1) is 0 and the exponential distribution excludes 0,
    # so we need the log to be nonzero and to not underflow when converted to half.
    # The fast __logf approximation can underflow, so set the log to -epsilon/2 for
    # arguments that are 1 or close to 1.
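    # Inverse transform sampling: if U ~ Uniform(0, 1], then -log(U) / rate is
    # distributed as Exponential(rate).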
    epsilon = torch.finfo(uniform_val.dtype).eps / 2
    condition = uniform_val >= 1.0 - epsilon
    log_uniform = torch.where(condition, -epsilon, torch.log(uniform_val))

    return -1 / rate * log_uniform


@register_decomposition(aten.geometric)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def geometric(self, p, generator=None):
    assert generator is None
    # TODO: fix inductor rand_like for integer, bool dtypes
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"geometric not implemented for {self.dtype}",
    )
    torch._check(
        0 < p and p < 1,
        lambda: f"geometric_ expects p to be in (0, 1), but got p={p}",
    )
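    # Inverse transform sampling for the geometric distribution over {1, 2, ...}
    # (the number of Bernoulli(p) trials up to and including the first success).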
    return torch.floor(torch.log1p(-torch.rand_like(self)) / math.log1p(-p)) + 1


@register_decomposition(aten.log_normal)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def log_normal(self, mean=1, std=2, generator=None):
    assert generator is None
    torch._check(
        not utils.is_complex_dtype(self.dtype)
        and not utils.is_integer_dtype(self.dtype)
        and not utils.is_boolean_dtype(self.dtype),
        lambda: f"log_normal not implemented for {self.dtype}",
    )
    torch._check(
        0 < std,
        lambda: f"log_normal_ expects std > 0.0, but found std={std}",
    )
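    # NOTE: mean and std parameterize the underlying normal distribution,
    # not the mean/std of the log-normal samples themselves.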
    return torch.exp(std * torch.randn_like(self) + mean)


# TODO: add support for functionalization aten.normal_functional
# NOTE: the device and dtype will be ignored when shape is None
@register_decomposition(aten.normal)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=(
        "mean",
        "std",
    ),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def normal(
    mean=0,
    std=1,
    size=None,
    *,
    generator=None,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
):
    assert layout is None or layout == torch.strided

    if not isinstance(std, TensorLike):
        torch._check(
            std >= 0, lambda: f"normal expects std >= 0.0, but found std {std}"
        )

    if size is None:
        tensors = tuple(t for t in (mean, std) if isinstance(t, TensorLike))
        torch._check(
            len(tensors) > 0,
            lambda: "normal expects that either mean or std is a tensor, or size is defined",
        )
        torch._check(
            layout is None and pin_memory is None,
            lambda: "Cannot pass layout, or pin_memory without size",
        )

        size = _broadcast_shapes(*(t.shape for t in tensors))
        dtype = tensors[0].dtype
        device = tensors[0].device
    else:
        torch._check(
            not isinstance(mean, TensorLike) and not isinstance(std, TensorLike),
            lambda: "normal expects mean and std to be scalars when size is defined",
        )
        dtype = torch.get_default_dtype() if dtype is None else dtype
        device = torch.device("cpu") if device is None else device

    normal_samples = prims.normal(
        size,
        mean=0.0,
        std=1.0,
        dtype=dtype,
        device=device,
        requires_grad=False,
        generator=generator,
    )
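    # Location-scale transform: std * N(0, 1) + mean is distributed as N(mean, std**2).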
    return std * normal_samples + mean


@register_decomposition(aten.normal_)
def normal_(self, mean=0, std=1, *, generator=None):
    return normal(mean, std, self.shape, out=self, generator=generator)


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def rad2deg(self: TensorLikeType):
    torch._check(
        not utils.is_complex_dtype(self.dtype),
        lambda: "rad2deg is not supported for complex tensors.",
    )
    M_180_PI = 57.295779513082320876798154814105170332405472466564
    return self * M_180_PI


@_make_elementwise_unary_reference(ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
def deg2rad(self: TensorLikeType):
    torch._check(
        not utils.is_complex_dtype(self.dtype),
        lambda: "deg2rad is not supported for complex tensors.",
    )
    M_PI_180 = 0.017453292519943295769236907684886127134428718885417
    return self * M_PI_180


@register_decomposition(aten.count_nonzero)
@out_wrapper()
def count_nonzero(self, dim: Optional[DimsType] = None):
    return (self != 0).sum(dim)


def _dot_check(self, other):
    torch._check(
        self.dim() == 1 and other.dim() == 1,
        lambda: f"1D tensors expected, but got {self.dim()}D and {other.dim()}D tensors",
    )

    torch._check(
        self.dtype == other.dtype,
        lambda: "dot : expected both vectors to have same dtype, but found "
        f"{self.dtype} and {other.dtype}",
    )

    def numel_error():
        return (
            f"inconsistent tensor size, expected tensor [{self.numel()}] and src [{other.numel()}] to have the"
            f"same number of elements, but got {self.numel()} and {other.numel()} elements respectively"
        )

    torch._check(self.numel() == other.numel(), numel_error)


def _dot_check_wrapper(fn):
    @wraps(fn)
    def wrapper(self, other):
        _dot_check(self, other)
        return fn(self, other)

    return wrapper


@register_decomposition(aten.dot)
@out_wrapper()
@_dot_check_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def dot(self, other):
    if self.is_complex():
        if self.is_conj():
            if other.is_conj():
                return torch.dot(self.conj(), other.conj()).conj()
            else:
                return torch.vdot(self.conj(), other)
        elif other.is_conj():
            return torch.vdot(other.conj(), self)

    return (self * other).sum()


@register_decomposition(aten.vdot)
@out_wrapper()
@_dot_check_wrapper
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self", "other"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def vdot(self, other):
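    # vdot conjugates its first argument: vdot(a, b) == sum(conj(a) * b)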
    if not self.is_complex():
        return torch.dot(self, other)

    if self.is_conj():
        if other.is_conj():
            return torch.vdot(other.conj(), self.conj())
        else:
            return torch.dot(self.conj(), other)
    elif other.is_conj():
        return torch.dot(self, other.conj()).conj()

    # The decomposition fails if you do self.conj()... not sure why
    return (self.conj_physical() * other).sum()


@register_decomposition(aten.select_scatter)
@out_wrapper()
def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int):
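    # Embeds `src` into `x` at the given index along `dim` (out-of-place), using a
    # broadcasted one-hot mask along `dim` and torch.where.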
    dim = utils.canonicalize_dim(x.ndim, dim)
    mask_shape = [1] * x.ndim
    mask_shape[dim] = -1
    if index < 0:
        index = index + x.shape[dim]
    mask = torch.arange(x.shape[dim], device=x.device).view(mask_shape) == index
    src = torch.unsqueeze(src, dim).expand(x.shape)
    return torch.where(mask, src, x)


# inplace
abs_ = _make_inplace(abs)
acos_ = _make_inplace(acos)
acosh_ = _make_inplace(acosh)
add_ = _make_inplace(add)
addcmul_ = _make_inplace(addcmul)
addcdiv_ = _make_inplace(addcdiv)
asin_ = _make_inplace(asin)
asinh_ = _make_inplace(asinh)
atan_ = _make_inplace(atan)
atanh_ = _make_inplace(atanh)
atan2_ = _make_inplace(atan2)
bitwise_and_ = _make_inplace(bitwise_and)
bitwise_left_shift_ = _make_inplace(bitwise_left_shift)
bitwise_not_ = _make_inplace(bitwise_not)
bitwise_or_ = _make_inplace(bitwise_or)
bitwise_right_shift_ = _make_inplace(bitwise_right_shift)
bitwise_xor_ = _make_inplace(bitwise_xor)
ceil_ = _make_inplace(ceil)
clamp_ = _make_inplace(clamp)
clamp_min_ = _make_inplace(clamp_min)
clamp_max_ = _make_inplace(clamp_max)
conj_physical_ = _make_inplace(conj_physical)
copysign_ = _make_inplace(copysign)
cos_ = _make_inplace(cos)
cosh_ = _make_inplace(cosh)
cumsum_ = _make_inplace(cumsum)
cumprod_ = _make_inplace(cumprod)
deg2rad_ = _make_inplace(deg2rad)
digamma_ = _make_inplace(digamma)
div_ = _make_inplace(div)
eq_ = _make_inplace(eq)
erf_ = _make_inplace(erf)
erfc_ = _make_inplace(erfc)
erfinv_ = _make_inplace(erfinv)
exp_ = _make_inplace(exp)
exp2_ = _make_inplace(exp2)
expm1_ = _make_inplace(expm1)
float_power_ = _make_inplace(float_power)
floor_ = _make_inplace(floor)
floor_divide_ = _make_inplace(floor_divide)
fmod_ = _make_inplace(fmod)
frac_ = _make_inplace(frac)
gcd_ = _make_inplace(gcd)
ge_ = _make_inplace(ge)
gt_ = _make_inplace(gt)
heaviside_ = _make_inplace(heaviside)
hypot_ = _make_inplace(hypot)
igamma_ = _make_inplace(igamma)
igammac_ = _make_inplace(igammac)
i0_ = _make_inplace(i0)
lcm_ = _make_inplace(lcm)
le_ = _make_inplace(le)
lerp_ = _make_inplace(lerp)
lgamma_ = _make_inplace(lgamma)
log10_ = _make_inplace(log10)
log1p_ = _make_inplace(log1p)
log2_ = _make_inplace(log2)
log_ = _make_inplace(log)
logical_and_ = _make_inplace(logical_and)
logical_not_ = _make_inplace(logical_not)
logical_or_ = _make_inplace(logical_or)
logical_xor_ = _make_inplace(logical_xor)
lt_ = _make_inplace(lt)
mul_ = _make_inplace(mul)
mvlgamma_ = _make_inplace(mvlgamma)
nan_to_num_ = _make_inplace(nan_to_num)
ne_ = _make_inplace(ne)
neg_ = _make_inplace(neg)
nextafter_ = _make_inplace(nextafter)
pow_ = _make_inplace(pow)
rad2deg_ = _make_inplace(rad2deg)
reciprocal_ = _make_inplace(reciprocal)
remainder_ = _make_inplace(remainder)
rsqrt_ = _make_inplace(rsqrt)
sgn_ = _make_inplace(sgn)
sigmoid_ = _make_inplace(sigmoid)
sign_ = _make_inplace(sign)
sin_ = _make_inplace(sin)
sinc_ = _make_inplace(sinc)
sinh_ = _make_inplace(sinh)
sqrt_ = _make_inplace(sqrt)
square_ = _make_inplace(square)
sub_ = _make_inplace(sub)
tan_ = _make_inplace(tan)
tanh_ = _make_inplace(tanh)
tril_ = _make_inplace(tril)
triu_ = _make_inplace(triu)
true_divide_ = _make_inplace(true_divide)
trunc_ = _make_inplace(trunc)
xlogy_ = _make_inplace(xlogy)
cauchy_ = _make_inplace(cauchy)
exponential_ = _make_inplace(exponential)
geometric_ = _make_inplace(geometric)
log_normal_ = _make_inplace(log_normal)
zero_ = _make_inplace(zero)

alias_copy = _make_copy_from_view(aten.alias)
as_strided_copy = _make_copy_from_view(aten.as_strided)
diagonal_copy = _make_copy_from_view(aten.diagonal)
expand_copy = _make_copy_from_view(aten.expand)
# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)
squeeze_copy = _make_copy_from_view(aten.squeeze)
permute_copy = _make_copy_from_view(aten.permute)
t_copy = _make_copy_from_view(aten.t)
transpose_copy = _make_copy_from_view(aten.transpose)
unbind_copy = _make_copy_from_view(aten.unbind)
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
view_copy = _make_copy_from_view(aten.view)


# xref: isStorage in torch/csrc/DynamicTypes.cpp
def _isStorage(obj):
    return isinstance(obj, (torch.TypedStorage, torch.UntypedStorage))


# xref: compute_sizes in torch/csrc/utils/tensor_new.cpp
def _compute_sizes(seq, scalar_type):
    MAX_DIMS = 128
    is_storage = _isStorage(seq)
    sizes = []
    # TODO: this is inaccurate, we actually test PySequence_Check
    while isinstance(seq, (list, tuple)):
        length = len(seq)
        if is_storage:
            length //= scalar_type.itemsize
        sizes.append(length)
        if len(sizes) > MAX_DIMS:
            raise ValueError(f"too many dimensions '{type(seq).__name__}'")
        if length == 0:
            break
        try:
            handle = seq[0]
        except Exception:
            raise ValueError(  # noqa: B904
                f"could not determine the shape of object type '{type(seq).__name__}'"
            )
        seq = handle

    return sizes


# xref: infer_scalar_type in torch/csrc/utils/tensor_new.cpp
def _infer_scalar_type(obj):
    if isinstance(obj, FloatLike):
        return torch.get_default_dtype()
    if isinstance(obj, IntLike) and not isinstance(obj, bool):  # careful!
        return torch.int64
    if isinstance(obj, BoolLike):
        return torch.bool
    if isinstance(obj, complex):
        default_dtype = torch.get_default_dtype()
        if default_dtype is torch.float:
            return torch.cfloat
        elif default_dtype is torch.double:
            return torch.cdouble
        elif default_dtype is torch.half:
            return torch.chalf
        else:
            raise RuntimeError("invalid default scalar type for complex")
    if isinstance(obj, torch.Tensor):
        return obj.dtype
    if isinstance(obj, str):
        raise TypeError(f"new(): invalid data type '{type(obj).__name__}'")
    # TODO: this is inaccurate, we actually test PySequence_Check
    if isinstance(obj, (list, tuple)):
        scalarType = None
        length = len(obj)
        # match NumPy semantics, except use default tensor type instead of
        # double.
        if length == 0:
            return torch.get_default_dtype()
        for i in range(length):
            cur_item = obj[i]
            # TODO: test this
            """
            if cur_item is obj:
                raise TypeError("new(): self-referential lists are incompatible")
            """
            item_scalarType = _infer_scalar_type(cur_item)  # recurse!
            if scalarType is not None:
                scalarType = torch.promote_types(scalarType, item_scalarType)
            else:
                scalarType = item_scalarType
            if scalarType is torch.cdouble:
                # this won't change (unless we hit undefined, but that will
                # fail later)
                return scalarType
        return scalarType
    raise RuntimeError(f"Could not infer dtype of {type(obj).__name__}")


# Analogous to recursive_store
# xref: recursive_store in torch/csrc/utils/tensor_new.cpp
def _recursive_build(
    scalarType: torch.dtype, obj: Union[TensorOrNumberLikeType, TensorSequenceType]
):
    if isinstance(obj, Tensor) and obj.numel() == 1:
        return obj.detach().to(dtype=scalarType, device="cpu", copy=True).view(())
    elif isinstance(obj, Tensor):
        # It is invalid to call ".tensor([...])" with a non-scalar tensor in eager mode
        # >>> torch.tensor([torch.randn(2)])
        # ValueError: only one element tensors can be converted to Python scalars
        #
        # But it is possible with a NumPy array
        # >>> torch.tensor([np.random.uniform(size=(2,))]).shape
        # torch.Size([1, 2])
        return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
    elif isinstance(obj, Number):
        return torch.scalar_tensor(obj, dtype=scalarType)

    # seq can be a list of tensors
    seq = obj
    return (
        torch.empty(0)
        if not seq
        else torch.stack([_recursive_build(scalarType, item) for item in seq])
    )


# xref: internal_new_from_data in torch/csrc/utils/tensor_new.cpp
def _internal_new_from_data(
    options,
    scalar_type,
    device_opt,
    data,
    copy_variables,
    copy_numpy,
    type_inference,
    pin_memory=False,
):
    if isinstance(data, torch.Tensor):
        torch._check(
            not pin_memory, lambda: "Can't pin tensor constructed from a variable"
        )
        var = data
        if copy_variables:
            var = var.detach()
        inferred_scalar_type = var.dtype if type_inference else scalar_type
        device = device_opt if device_opt is not None else var.device
        return var.to(
            device=device,
            dtype=inferred_scalar_type,
            non_blocking=False,
            copy=copy_variables,
        )

    # TODO
    if hasattr(data, "__cuda_array_interface__"):
        return NotImplemented

    # TODO: test for numpy input with PyArray_Check

    device = device_opt if device_opt is not None else options["device"]
    inferred_scalar_type = _infer_scalar_type(data) if type_inference else scalar_type

    # NB: Don't need to avoid tracing, as we aren't going to do any manual
    # pointer filling tricks
    if _isStorage(data):
        return NotImplemented
    else:
        if torch.device(device).type == "meta":
            return NotImplemented

        # In the C implementation, we would directly start poking the memory
        # of a freshly allocated CPU tensor.  Here, we're going to do an
        # alternate, heinously slow implementation: turn each individual
        # scalar into a tensor, and then repeatedly stack them together
        tensor = _recursive_build(inferred_scalar_type, data)

        tensor = tensor.to(device, inferred_scalar_type, non_blocking=False, copy=False)

    # NB: lift_fresh is not needed, because we built the tensor from scalars
    # guaranteeing a fresh tensor in this case
    return tensor


# xref: tensor_ctor in torch/csrc/utils/tensor_new.cpp
def tensor(data, *, dtype=None, device=None, pin_memory=False, requires_grad=False):
    # TODO (or not): support names kwarg
    if isinstance(data, torch.Tensor):
        warnings.warn(
            "To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() "
            "or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor)",
            UserWarning,
            stacklevel=2,
        )
    type_inference = dtype is None
    new_tensor = _internal_new_from_data(
        # device="cpu" because that's what you get with torch.tensor(2) no
        # device by default
        {"device": "cpu"},  # TODO: use torch.get_default_tensor_type
        dtype if dtype is not None else torch.get_default_dtype(),
        device,
        data,
        copy_variables=True,
        copy_numpy=True,
        type_inference=type_inference,
        pin_memory=pin_memory,
    )
    new_tensor.detach_()
    if requires_grad:
        new_tensor.requires_grad_(requires_grad)
    return new_tensor


# Views
# We can't model these as above, as the `op(a, out=a)` pattern does not work for view
# functions: `out=` does not reshape the input, it just copies the result into it.

# squeeze_ = _make_inplace(squeeze)
# t_ = _make_inplace(t)
# transpose_ = _make_inplace(transpose)
# unsqueeze_ = _make_inplace(unsqueeze)


import torch._refs._conversions
import torch._refs.fft
import torch._refs.linalg
import torch._refs.nn.functional
import torch._refs.special
