from ..runtime.jit import jit
from . import core as tl
from . import math

# Number of Philox rounds used when a caller does not pass `n_rounds`.
N_ROUNDS_DEFAULT = 10

# -------------------
# randint
# -------------------


@jit
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).

    The round constants are selected from `c0.dtype`: `tl.uint32` picks the
    32-bit set, `tl.uint64` the 64-bit set, and any other dtype fails a static
    assert at compile time. Only `c0.dtype` is inspected; presumably the other
    counter and key words share it (`philox` casts them that way) — confirm
    for any other caller.

    :param c0, c1, c2, c3: The four counter words to scramble.
    :param k0, k1: The two key words; advanced by a fixed increment after every round.
    :param n_rounds: Number of rounds (compile-time constant).
    :return: The scrambled counter as a tuple `(c0, c1, c2, c3)`.
    """
    # Round multipliers (PHILOX_ROUND_*) and key-schedule increments
    # (PHILOX_KEY_*) for the selected word width.
    if c0.dtype == tl.uint32:
        PHILOX_KEY_A: tl.constexpr = 0x9E3779B9
        PHILOX_KEY_B: tl.constexpr = 0xBB67AE85
        PHILOX_ROUND_A: tl.constexpr = 0xD2511F53
        PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57
    else:
        tl.static_assert(c0.dtype == tl.uint64, "dtype not supported in philox_impl")
        PHILOX_KEY_A: tl.constexpr = 0x9E3779B97F4A7C15
        PHILOX_KEY_B: tl.constexpr = 0xBB67AE8584CAA73B
        PHILOX_ROUND_A: tl.constexpr = 0xD2E7470EE14C6C93
        PHILOX_ROUND_B: tl.constexpr = 0xCA5A826395121157

    # static_range unrolls the rounds at compile time.
    for _ in tl.static_range(n_rounds):
        # update random state
        A = PHILOX_ROUND_A
        B = PHILOX_ROUND_B
        # Snapshot the old c0/c2: c0 is overwritten before the c2 update reads
        # it, and c2 is overwritten before the c1 update reads it.
        _c0, _c2 = c0, c2
        # High words of the two products, mixed with the other words and the key.
        c0 = math.umulhi(B, _c2) ^ c1 ^ k0
        c2 = math.umulhi(A, _c0) ^ c3 ^ k1
        # Low words of the same products; wraparound is intended, so overflow
        # sanitization is disabled.
        c1 = tl.mul(B, _c2, sanitize_overflow=False)
        c3 = tl.mul(A, _c0, sanitize_overflow=False)
        # raise key (wrapping add, again without overflow sanitization)
        k0 = tl.add(k0, PHILOX_KEY_A, sanitize_overflow=False)
        k1 = tl.add(k1, PHILOX_KEY_B, sanitize_overflow=False)
    return c0, c1, c2, c3


@jit
def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Apply `n_rounds` of Philox to the counter (c0, c1, c2, c3), keyed by `seed`.

    The counter words are bit-cast to an unsigned type whose width matches
    `c0` (32 or 64 bits; anything else fails a static assert).
    - With 32-bit counters, the seed (widened to 64 bits) is split into
      low/high 32-bit key words.
    - With 64-bit counters, the whole seed is the low key word and the high
      key word is zero.

    :param seed: Integer seed (Python int or integer tensor).
    :param c0, c1, c2, c3: Counter words; only the bit width of `c0` is inspected.
    :param n_rounds: Number of Philox rounds (compile-time constant).
    :return: Four unsigned tensors (uint32 or uint64) produced by `philox_impl`.
    """
    seed = tl.to_tensor(seed)
    tl.static_assert(seed.dtype.is_int())
    seed = seed.to(tl.uint64)
    c0 = tl.to_tensor(c0)
    c1 = tl.to_tensor(c1)
    c2 = tl.to_tensor(c2)
    c3 = tl.to_tensor(c3)
    if tl.constexpr(c0.dtype.primitive_bitwidth) == 32:
        word_ty = tl.uint32
        key_lo = (seed & 0xffffffff).to(tl.uint32)
        key_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32)
    else:
        tl.static_assert(tl.constexpr(c0.dtype.primitive_bitwidth) == 64, "bitwidth not supported in philox")
        word_ty = tl.uint64
        key_lo = seed
        key_hi = tl.full((1, ), 0, dtype=word_ty)
    return philox_impl(
        c0.to(word_ty, bitcast=True),
        c1.to(word_ty, bitcast=True),
        c2.to(word_ty, bitcast=True),
        c3.to(word_ty, bitcast=True),
        key_lo,
        key_hi,
        n_rounds,
    )


@jit
def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns a single
    block of random unsigned integers: :code:`uint32` for 32-bit offsets,
    :code:`uint64` for 64-bit offsets.

    If you need multiple streams of random numbers,
    using `randint4x` is likely to be faster than calling `randint` 4 times.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (compile-time constant).
    """
    # Only the first of the four Philox output words is kept.
    ret, _, _, _ = randint4x(seed, offset, n_rounds)
    return ret


@jit
def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block, returns four
    blocks of random unsigned integers: :code:`uint32` for 32-bit offsets,
    :code:`uint64` for 64-bit offsets.

    This is the maximally efficient entry point
    to Triton's Philox pseudo-random number generator.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (compile-time constant).
    """
    # Zero counter words matching `offset` (for a tensor offset, multiplying
    # by 0 keeps its shape and dtype); the offset itself is counter word 0.
    _0 = offset * 0
    return philox(seed, offset, _0, _0, _0, n_rounds)


# -------------------
# rand
# -------------------

# @jit
# def uint32_to_uniform_float(x):
#     """
#     Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1).
#     """
#     two_to_the_minus_32: tl.constexpr = 2.328306e-10
#     return x * two_to_the_minus_32


@jit
def uint_to_uniform_float(x):
    """
    Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).

    Accepts 32-bit (``uint32``/``int32``) or 64-bit (``uint64``/``int64``)
    integers; any other dtype fails a static assert at compile time.
    """
    # TODO: fix frontend issues and cleanup
    # conditions can be simplified
    # scale is about (1 - 2**-23) * 2**-(N_BITS - 1), i.e. just below
    # 1 / 2**(N_BITS - 1). The 64-bit constant equals this exactly; the
    # 32-bit constant is slightly smaller.
    if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):
        # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)
        x = x.to(tl.int32, bitcast=True)
        scale = 4.6566127342e-10
    else:
        tl.static_assert(tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64),
                         "dtype not supported in uint_to_uniform_float")
        x = x.to(tl.int64, bitcast=True)
        scale = 1.0842020432385337e-19
    # Fold negative values onto [0, MAX_INT]. In two's complement, ~x == -x - 1,
    # but ~x never negates the minimum value, so it cannot overflow, even when
    # integer overflow sanitization is enabled.
    x = tl.where(x < 0, ~x, x)
    return x * scale


@jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`U(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (compile-time constant).
    """
    source = randint(seed, offset, n_rounds)
    return uint_to_uniform_float(source)


@jit
def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offsets` block,
    returns 4 blocks of random :code:`float32` in :math:`U(0, 1)`.

    Each of the four Philox output words is mapped to [0, 1) with
    `uint_to_uniform_float`.

    :param seed: The seed for generating random numbers.
    :param offsets: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (compile-time constant).
    """
    r0, r1, r2, r3 = randint4x(seed, offsets, n_rounds)
    return (
        uint_to_uniform_float(r0),
        uint_to_uniform_float(r1),
        uint_to_uniform_float(r2),
        uint_to_uniform_float(r3),
    )


# -------------------
# randn
# -------------------


@jit
def pair_uniform_to_normal(u1, u2):
    """
    Box-Muller transform: map two uniform blocks to two normal blocks.

    :param u1: Uniform samples driving the radius; clamped to at least 1e-7
        so the logarithm stays finite.
    :param u2: Uniform samples driving the angle (scaled by 2*pi).
    :return: The pair `(r * cos(theta), r * sin(theta))`.
    """
    theta = 6.283185307179586 * u2
    radius = math.sqrt(-2.0 * math.log(tl.maximum(1.0e-7, u1)))
    return radius * math.cos(theta), radius * math.sin(theta)


@jit
def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (compile-time constant).
    """
    # Two uniform streams feed one Box-Muller pair; only the first output is kept.
    i1, i2, _, _ = randint4x(seed, offset, n_rounds)
    u1 = uint_to_uniform_float(i1)
    u2 = uint_to_uniform_float(i2)
    n1, _ = pair_uniform_to_normal(u1, u2)
    return n1


@jit
def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):
    """
    Given a :code:`seed` scalar and an :code:`offset` block,
    returns 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`.

    :param seed: The seed for generating random numbers.
    :param offset: The offsets to generate random numbers for.
    :param n_rounds: Number of Philox rounds (compile-time constant).
    """
    # Four uniform streams -> two Box-Muller pairs -> four normal blocks.
    u1, u2, u3, u4 = rand4x(seed, offset, n_rounds)
    n1, n2 = pair_uniform_to_normal(u1, u2)
    n3, n4 = pair_uniform_to_normal(u3, u4)
    return n1, n2, n3, n4
