"""Imported from the recipes section of the itertools documentation.

All functions taken from the recipes section of the itertools library docs
[1]_.
Some backward-compatible usability improvements have been made.

.. [1] http://docs.python.org/library/itertools.html#recipes

"""

import random

from collections import deque
from contextlib import suppress
from collections.abc import Sized
from functools import lru_cache, partial
from itertools import (
    accumulate,
    chain,
    combinations,
    compress,
    count,
    cycle,
    groupby,
    islice,
    product,
    repeat,
    starmap,
    tee,
    zip_longest,
)
from math import prod, comb, isqrt, gcd
from operator import mul, not_, itemgetter, getitem
from random import randrange, sample, choice
from sys import hexversion

# Public API of this module: the names exported by a star import.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'is_prime',
    'iter_except',
    'iter_index',
    'loops',
    'matmul',
    'multinomial',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'reshape',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'totient',
    'transpose',
    'triplewise',
    'unique',
    'unique_everseen',
    'unique_justseen',
]

# Unique sentinel object.  ``_zip_equal_generator`` passes it to
# ``zip_longest`` as the fill value so that an input which ran out early can
# be told apart from any real element.
_marker = object()


# ``zip(strict=True)`` only exists on Python 3.10+.  Build the strict variant
# and probe it once at import time; interpreters that reject the keyword
# raise TypeError and fall back to the plain built-in.
try:
    _zip_strict = partial(zip, strict=True)
    _zip_strict()
except TypeError:
    _zip_strict = zip


# Prefer ``math.sumprod`` (added in Python 3.12); on older interpreters fall
# back to the pure-Python ``dotproduct`` recipe defined in this module.
try:
    from math import sumprod as _sumprod
except ImportError:

    def _sumprod(x, y):
        return dotproduct(x, y)


def take(n, iterable):
    """Collect the first *n* items of *iterable* into a list.

        >>> take(3, range(10))
        [0, 1, 2]

    When *iterable* holds fewer than *n* items, everything it has is
    returned.

        >>> take(10, range(3))
        [0, 1, 2]

    """
    leading = islice(iterable, n)
    return list(leading)


def tabulate(function, start=0):
    """Return an iterator over ``function(start)``, ``function(start + 1)``,
    ``function(start + 2)``...

    *function* should accept a single integer argument.

    *start* defaults to 0 and is incremented by one every time the iterator
    is advanced.

        >>> square = lambda x: x ** 2
        >>> iterator = tabulate(square, -3)
        >>> take(4, iterator)
        [9, 4, 1, 0]

    """
    arguments = count(start)
    return map(function, arguments)


def tail(n, iterable):
    """Return an iterator over the final *n* items of *iterable*.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    # Inputs without a length must be read in full; a bounded deque keeps only
    # the last n items.  No explicit Iterable check is made: deque (or islice
    # below) raises TypeError itself when the argument is not iterable.
    if not isinstance(iterable, Sized):
        return iter(deque(iterable, maxlen=n))

    # Sized inputs let us skip straight to the final elements.
    first_kept = max(0, len(iterable) - n)
    return islice(iterable, first_kept, None)


def consume(iterator, n=None):
    """Advance *iterator* by *n* steps. If *n* is ``None``, exhaust it
    completely.

    Nothing is returned; the items are discarded. The whole iterator is
    consumed by default, and the optional second argument caps how many
    items are dropped.

        >>> i = (x for x in range(10))
        >>> next(i)
        0
        >>> consume(i, 3)
        >>> next(i)
        4
        >>> consume(i)
        >>> next(i)
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
        StopIteration

    When fewer than *n* items remain, the iterator is simply exhausted.

        >>> i = (x for x in range(3))
        >>> consume(i, 5)
        >>> next(i)
        Traceback (most recent call last):
          File "<stdin>", line 1, in <module>
        StopIteration

    """
    # Both branches delegate the looping to C-implemented helpers.
    if n is not None:
        # Asking for the empty slice [n:n] skips n items and yields nothing.
        next(islice(iterator, n, n), None)
        return

    # A zero-length deque drains the iterator without storing anything.
    deque(iterator, maxlen=0)


def nth(iterable, n, default=None):
    """Return the item at position *n*, or *default* if there is none.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    from_n_onwards = islice(iterable, n, None)
    return next(from_n_onwards, default)


def all_equal(iterable, key=None):
    """
    Return ``True`` when every element equals every other element.

        >>> all_equal('aaaa')
        True
        >>> all_equal('aaab')
        False

    An optional *key* function, taking one argument, transforms each item
    before the comparison:

        >>> all_equal('AaaA', key=str.casefold)
        True
        >>> all_equal([1, 2, 3], key=lambda x: x < 10)
        True

    """
    runs = groupby(iterable, key)
    # groupby yields (key, group) tuples, never None, so None can safely
    # stand for "no more runs".  Skip the first run if there is one...
    next(runs, None)
    # ...and the elements are all equal exactly when no second run follows.
    return next(runs, None) is None


def quantify(iterable, pred=bool):
    """Return how many times the predicate is true.

    The results of ``pred(item)`` are summed, so *pred* is expected to
    return booleans (or numbers).

    >>> quantify([True, False, True])
    2

    """
    return sum(pred(item) for item in iterable)


def pad_none(iterable):
    """Yield the elements of *iterable*, followed by ``None`` forever.

        >>> take(5, pad_none(range(3)))
        [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    endless_nones = repeat(None)
    return chain(iterable, endless_nones)


# Alias for the recipe's spelling of the name.
padnone = pad_none


def ncycles(iterable, n):
    """Return the elements of *iterable* repeated *n* times over.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Snapshot the input once so that one-shot iterators can be replayed.
    snapshot = tuple(iterable)
    passes = repeat(snapshot, n)
    return chain.from_iterable(passes)


def dotproduct(vec1, vec2):
    """Return the dot product of two iterables.

    Corresponding elements are multiplied and the products summed; pairing
    stops at the shorter input.

    >>> dotproduct([10, 15, 12], [0.65, 0.80, 1.25])
    33.5
    >>> 10 * 0.65 + 15 * 0.80 + 12 * 1.25
    33.5

    In Python 3.12 and later, use ``math.sumprod()`` instead.
    """
    return sum(a * b for a, b in zip(vec1, vec2))


def flatten(listOfLists):
    """Return an iterator that removes one level of nesting from
    *listOfLists*.

        >>> list(flatten([[0, 1], [2, 3]]))
        [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    flattened = chain.from_iterable(listOfLists)
    return flattened


def repeatfunc(func, times=None, *args):
    """Call ``func(*args)`` over and over, returning an iterator of the
    results.

    When *times* is given, the iterator ends after that many calls:

        >>> from operator import add
        >>> times = 4
        >>> args = 3, 5
        >>> list(repeatfunc(add, times, *args))
        [8, 8, 8, 8]

    When *times* is ``None`` the iterator never ends:

        >>> from random import randrange
        >>> times = None
        >>> args = 1, 11
        >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
        [2, 4, 8, 1, 8, 4]

    """
    # One copy of the argument tuple per call: unbounded or *times* copies.
    argument_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, argument_stream)


def _pairwise(iterable):
    """Return an iterator of overlapping pairs of consecutive items from
    *iterable*.

    >>> take(4, pairwise(count()))
    [(0, 1), (1, 2), (2, 3), (3, 4)]

    On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.

    """
    current, following = tee(iterable)
    # Put the second branch one item ahead of the first.
    next(following, None)
    return zip(current, following)


# Choose the implementation of ``pairwise`` at import time.
try:
    from itertools import pairwise as itertools_pairwise
except ImportError:
    # No itertools.pairwise available: use the pure-Python recipe directly.
    pairwise = _pairwise
else:

    # Thin Python-level wrapper around the itertools implementation; the
    # recipe's docstring is attached to it below.
    def pairwise(iterable):
        return itertools_pairwise(iterable)

    pairwise.__doc__ = _pairwise.__doc__


class UnequalIterablesError(ValueError):
    """Raised when iterables that must have equal lengths do not.

    *details*, when given, supplies the values for the message in order:
    the length of the iterable at index 0, the index of the mismatching
    iterable, and that iterable's length.
    """

    def __init__(self, details=None):
        message = 'Iterables have different lengths'
        if details is not None:
            suffix = ': index 0 has length {}; index {} has length {}'
            message = message + suffix.format(*details)

        super().__init__(message)


def _zip_equal_generator(iterables):
    """Yield tuples like ``zip(*iterables)``, raising
    ``UnequalIterablesError`` as soon as one input runs out before the
    others.
    """
    # zip_longest pads exhausted inputs with the module sentinel, so finding
    # the sentinel in a tuple means the lengths differ.
    for row in zip_longest(*iterables, fillvalue=_marker):
        if any(item is _marker for item in row):
            raise UnequalIterablesError()
        yield row


def _zip_equal(*iterables):
    """Zip *iterables* together, requiring them all to have the same length.

    When every input supports ``len()``, the sizes are compared up front and
    ``UnequalIterablesError`` is raised immediately on a mismatch; otherwise
    the built-in ``zip`` is returned.  When any input has no length, a
    generator is returned that raises ``UnequalIterablesError`` lazily, at
    the point where one input runs out before the others.

    With no inputs at all, an empty iterator is returned, like ``zip()``.
    """
    # Nothing to compare.  Without this guard ``iterables[0]`` below raises
    # IndexError, which the ``except TypeError`` clause does not catch.
    if not iterables:
        return zip()

    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                raise UnequalIterablesError(details=(first_size, i, size))
        # All sizes are equal, we can use the built-in zip.
        return zip(*iterables)
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)


def grouper(iterable, n, incomplete='fill', fillvalue=None):
    """Split *iterable* into consecutive groups of exactly *n* elements.

    >>> list(grouper('ABCDEF', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    *incomplete* and *fillvalue* decide what happens when the number of
    elements is not a multiple of *n*.

    With *incomplete* set to `'fill'`, the final group is padded with
    *fillvalue*.

    >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]

    With *incomplete* set to `'ignore'`, the partial final group is dropped.

    >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    With *incomplete* set to `'strict'`, a subclass of `ValueError` is raised.

    >>> iterator = grouper('ABCDEFG', 3, incomplete='strict')
    >>> list(iterator)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    UnequalIterablesError

    Any other value of *incomplete* raises ``ValueError`` straight away.

    """
    # n references to one shared iterator: each output tuple pulls n
    # consecutive items from it.
    shared = iter(iterable)
    slots = [shared] * n

    if incomplete == 'fill':
        groups = zip_longest(*slots, fillvalue=fillvalue)
    elif incomplete == 'strict':
        groups = _zip_equal(*slots)
    elif incomplete == 'ignore':
        groups = zip(*slots)
    else:
        raise ValueError('Expected fill, strict, or ignore')
    return groups


def roundrobin(*iterables):
    """Yield one item from each input in turn, cycling until all inputs are
    exhausted.

        >>> list(roundrobin('ABC', 'D', 'EF'))
        ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Algorithm credited to George Sakkis
    pending = map(iter, iterables)
    # Every pass stops when one input raises StopIteration inside ``map``;
    # the next pass then cycles over one iterator fewer.
    for remaining in range(len(iterables), 0, -1):
        pending = cycle(islice(pending, remaining))
        yield from map(next, pending)


def partition(pred, iterable):
    """
    Split *iterable* into a 2-tuple of iterables according to *pred*.
    The first yields the items for which ``pred(item)`` is false.
    The second yields the items for which ``pred(item)`` is true.

        >>> is_odd = lambda x: x % 2 != 0
        >>> iterable = range(10)
        >>> even_items, odd_items = partition(is_odd, iterable)
        >>> list(even_items), list(odd_items)
        ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

        >>> iterable = [0, 1, False, True, '', ' ']
        >>> false_items, true_items = partition(None, iterable)
        >>> list(false_items), list(true_items)
        ([0, False, ''], [1, True, ' '])

    """
    predicate = bool if pred is None else pred

    # Three copies of the input: one per output, one to feed the predicate.
    for_false, for_true, for_testing = tee(iterable, 3)
    # The predicate is evaluated once per item; both outputs share the
    # verdicts through a second tee.
    verdicts_false, verdicts_true = tee(map(predicate, for_testing))
    rejected = compress(for_false, map(not_, verdicts_false))
    accepted = compress(for_true, verdicts_true)
    return (rejected, accepted)


def powerset(iterable):
    """Yield every subset of the elements of *iterable*, as tuples, from
    the smallest to the largest.

        >>> list(powerset([1, 2, 3]))
        [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    :func:`powerset` will operate on iterables that aren't :class:`set`
    instances, so repeated elements in the input will produce repeated elements
    in the output.

        >>> seq = [1, 1, 0]
        >>> list(powerset(seq))
        [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]

    For a variant that efficiently yields actual :class:`set` instances, see
    :func:`powerset_of_sets`.
    """
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    return chain.from_iterable(combinations(pool, size) for size in sizes)


def unique_everseen(iterable, key=None):
    """
    Yield each distinct element once, in order of first appearance.

        >>> list(unique_everseen('AAAABBBCCDAABBB'))
        ['A', 'B', 'C', 'D']
        >>> list(unique_everseen('ABBCcAD', str.lower))
        ['A', 'B', 'C', 'D']

    The input may mix hashable and unhashable items.  Unhashable items are
    tracked in a list, which makes the function slower (i.e., `O(n^2)`) for
    them.

    Since ``list`` objects are unhashable, the *key* parameter can be used
    to turn each list into a (hashable) tuple and avoid the slowdown.

        >>> iterable = ([1, 2], [2, 3], [1, 2])
        >>> list(unique_everseen(iterable))  # Slow
        [[1, 2], [2, 3]]
        >>> list(unique_everseen(iterable, key=tuple))  # Faster
        [[1, 2], [2, 3]]

    Similarly, you may want to convert unhashable ``set`` objects with
    ``key=frozenset``. For ``dict`` objects,
    ``key=lambda x: frozenset(x.items())`` can be used.

    """
    hashable_seen = set()
    unhashable_seen = []

    for item in iterable:
        fingerprint = item if key is None else key(item)
        try:
            # The membership test on a set raises TypeError for unhashable
            # fingerprints, which routes them to the list-based fallback.
            if fingerprint in hashable_seen:
                continue
            hashable_seen.add(fingerprint)
            yield item
        except TypeError:
            if fingerprint in unhashable_seen:
                continue
            unhashable_seen.append(fingerprint)
            yield item


def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing runs of consecutive duplicates
    into their first element.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    runs = groupby(iterable, key)
    if key is None:
        # Without a key function the run key is the element itself.
        return map(itemgetter(0), runs)

    # With a key, return the first original element of each run.
    members = map(itemgetter(1), runs)
    return map(next, members)


def unique(iterable, key=None, reverse=False):
    """Yield the distinct elements of *iterable* in sorted order.

    >>> list(unique([[1, 2], [3, 4], [1, 2]]))
    [[1, 2], [3, 4]]

    *key* and *reverse* control the sort; *key* is also used to decide
    which elements count as duplicates.

    >>> list(unique('ABBcCAD', str.casefold))
    ['A', 'B', 'c', 'D']
    >>> list(unique('ABBcCAD', str.casefold, reverse=True))
    ['D', 'c', 'B', 'A']

    The elements in *iterable* need not be hashable, but they must be
    comparable for sorting to work.
    """
    # Sorting places equal elements next to each other, so collapsing
    # consecutive duplicates is enough.
    ordered = list(iterable)
    ordered.sort(key=key, reverse=reverse)
    return unique_justseen(ordered, key=key)


def iter_except(func, exception, first=None):
    """Yields results from a function repeatedly until an exception is raised.

    Converts a call-until-exception interface to an iterator interface.
    Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
    to end the loop.

    *func* is called with no arguments.  If *first* is given, it is called
    once (also with no arguments) and its result is yielded before the first
    call to *func*.

        >>> l = [0, 1, 2]
        >>> list(iter_except(l.pop, IndexError))
        [2, 1, 0]

    Multiple exceptions can be specified as a stopping condition:

        >>> l = [1, 2, 3, '...', 4, 5, 6]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        [7, 6, 5]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        [4, 3, 2]
        >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
        []

    """
    # Both the optional *first* call and the repeated *func* calls run inside
    # ``suppress``, so *exception* raised by either one ends the generator
    # quietly instead of propagating to the caller.
    with suppress(exception):
        if first is not None:
            yield first()
        while True:
            yield func()


def first_true(iterable, default=None, pred=None):
    """
    Return the first truthy value in *iterable*.

    *default* is returned when no such value exists.

    When *pred* is not None, the first item for which ``pred(item)`` is
    true is returned instead.

        >>> first_true(range(10))
        1
        >>> first_true(range(10), pred=lambda x: x > 5)
        6
        >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
        'missing'

    """
    # filter() with a predicate of None keeps the truthy items themselves.
    matches = filter(pred, iterable)
    return next(matches, default)


def random_product(*args, repeat=1):
    """Pick one random item from each of the input iterables.

        >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
        ('c', 3, 'Z')

    The keyword argument *repeat* makes the whole sequence of inputs be
    drawn from that many times.

        >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
        ('a', 2, 'd', 3)

    This is equivalent to taking a random selection from
    ``itertools.product(*args, repeat=repeat)``.

    """
    # Materialize every input once, then repeat the whole list of pools.
    pools = list(map(tuple, args)) * repeat
    return tuple(map(choice, pools))


def random_permutation(iterable, r=None):
    """Return a random permutation of length *r* of the elements in
    *iterable*.

    When *r* is omitted or ``None``, it defaults to the number of elements
    in *iterable*.

        >>> random_permutation(range(5))  # doctest:+SKIP
        (3, 4, 0, 1, 2)

    This is equivalent to taking a random selection from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    if r is None:
        r = len(pool)
    return tuple(sample(pool, r))


def random_combination(iterable, r):
    """Return a random subsequence of length *r* of the elements in
    *iterable*; the chosen elements keep their original relative order.

        >>> random_combination(range(5), 3)  # doctest:+SKIP
        (2, 3, 4)

    This is equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    pool = tuple(iterable)
    # Draw r distinct positions, then restore the input order.
    positions = sample(range(len(pool)), r)
    positions.sort()
    return tuple(map(pool.__getitem__, positions))


def random_combination_with_replacement(iterable, r):
    """Return a random subsequence of length *r* of the elements in
    *iterable*, where any element may be chosen more than once.

        >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
        (0, 0, 1, 2, 2)

    This is equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool)
    # r independent positions, then sorted back into input order.
    positions = [randrange(size) for _ in range(r)]
    positions.sort()
    return tuple(pool[position] for position in positions)


def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

        >>> nth_combination(range(5), 3, 5)
        (0, 3, 4)

    Negative values of *index* count from the end, as with list indexing.

    ``ValueError`` will be raised if *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if (r < 0) or (r > n):
        raise ValueError(f'r must be between 0 and {n}, not {r}')

    # Total number of combinations; also the bound for a valid index.
    c = comb(n, r)

    if index < 0:
        index += c

    if not 0 <= index < c:
        raise IndexError('combination index out of range')

    result = []
    while r:
        # c becomes the number of combinations that start with the current
        # candidate element (choose the remaining r - 1 from what follows).
        c, n, r = c * r // n, n - 1, r - 1
        # Skip whole groups of combinations until the one holding *index*.
        while index >= c:
            index -= c
            c, n = c * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)


def prepend(value, iterator):
    """Return an iterator that yields *value* first and then the elements
    of *iterator*.

        >>> value = '0'
        >>> iterator = ['1', '2', '3']
        >>> list(prepend(value, iterator))
        ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    head = (value,)
    return chain(head, iterator)


def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    For example, multiplying ``(x² -x - 20)`` by ``(x - 3)``
    gives ``(x³ -4x² -17x + 60)``.

        >>> list(convolve([1, -1, -20], [1, -3]))
        [1, -4, -17, 60]

    Examples of popular kinds of kernels:

    * The kernel ``[0.25, 0.25, 0.25, 0.25]`` computes a moving average.
      For image data, this blurs the image and reduces noise.
    * The kernel ``[1/2, 0, -1/2]`` estimates the first derivative of
      a function evaluated at evenly spaced inputs.
    * The kernel ``[1, -2, 1]`` estimates the second derivative of a
      function evaluated at evenly spaced inputs.

    Convolution is mathematically commutative, but the two inputs are
    treated differently here: *signal* is read lazily and may be infinite,
    whereas *kernel* is read in full before the first result is produced.

    Supports all numeric types: int, float, complex, Decimal, Fraction.

    References:

    * Article:  https://betterexplained.com/articles/intuitive-convolution/
    * Video by 3Blue1Brown:  https://www.youtube.com/watch?v=KuXjwB4LzSA

    """
    # This implementation comes from an older version of the itertools
    # documentation.  While the newer implementation is a bit clearer,
    # this one was kept because the inlined window logic is faster
    # and it avoids an unnecessary deque-to-tuple conversion.
    flipped = tuple(kernel)[::-1]
    width = len(flipped)
    # The window starts out full of zeros, and the signal is padded with
    # width - 1 trailing zeros so that the tail of the output is produced.
    window = deque([0], maxlen=width) * width
    padded_signal = chain(signal, repeat(0, width - 1))
    for value in padded_signal:
        window.append(value)
        yield _sumprod(flipped, window)


def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that also gives access to everything
    after the leading run of matching items.

         >>> it = iter('ABCdEfGhI')
         >>> all_upper, remainder = before_and_after(str.isupper, it)
         >>> ''.join(all_upper)
         'ABC'
         >>> ''.join(remainder) # takewhile() would lose the 'd'
         'dEfGhI'

    Returns a pair of iterators.  The first iterator must be fully consumed
    before the second one can produce valid results.
    """
    source = iter(it)
    # Holds the first item that fails the predicate, so that it is not lost.
    handoff = []

    def leading_items():
        for item in source:
            if not predicate(item):
                handoff.append(item)
                return
            yield item

    # Note: this is different from itertools recipes to allow nesting
    # before_and_after remainders into before_and_after again. See tests
    # for an example.
    remainder = chain(handoff, source)

    return leading_items(), remainder


def triplewise(iterable):
    """Return an iterator of overlapping triplets of consecutive items from
    *iterable*.

    >>> list(triplewise('ABCDE'))
    [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]

    """
    # This deviates from the itertools documentation recipe - see
    # https://github.com/more-itertools/more-itertools/issues/889
    first, second, third = tee(iterable, 3)
    # Stagger the branches: third runs two items ahead, second one ahead.
    next(third, None)
    next(third, None)
    next(second, None)
    return zip(first, second, third)


def _sliding_window_islice(iterable, n):
    """Sliding windows of width *n*, built by zipping *n* staggered copies
    of *iterable*.
    """
    # Fast path for small, non-zero values of n.
    branches = tee(iterable, n)
    for offset, branch in enumerate(branches):
        # Taking the empty slice [offset:offset] advances the branch by
        # ``offset`` items without yielding anything.
        next(islice(branch, offset, offset), None)
    return zip(*branches)


def _sliding_window_deque(iterable, n):
    """Sliding windows of width *n*, maintained in one bounded deque."""
    # Normal path for other values of n.
    source = iter(iterable)
    # Pre-load n - 1 items; every further item completes a full window and
    # pushes the oldest one out.
    buffer = deque(islice(source, n - 1), maxlen=n)
    for item in source:
        buffer.append(item)
        yield tuple(buffer)


def sliding_window(iterable, n):
    """Return an iterator of overlapping windows of width *n* over
    *iterable*.

        >>> list(sliding_window(range(6), 4))
        [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    Nothing is yielded when *iterable* has fewer than *n* items:

        >>> list(sliding_window(range(3), 4))
        []

    ``ValueError`` is raised when *n* is less than one.

    For a variant with more features, see :func:`windowed`.
    """
    # Choose an implementation according to the window width.
    if n > 20:
        windows = _sliding_window_deque(iterable, n)
    elif n > 2:
        windows = _sliding_window_islice(iterable, n)
    elif n == 2:
        windows = pairwise(iterable)
    elif n == 1:
        windows = zip(iterable)
    else:
        raise ValueError(f'n should be at least one, not {n}')
    return windows


def subslices(iterable):
    """Return an iterator over every contiguous non-empty subslice of
    *iterable*, each one as a list.

        >>> list(subslices('ABC'))
        [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.
    """
    items = list(iterable)
    # Every (start, stop) pair with start < stop names one non-empty slice.
    bounds = combinations(range(len(items) + 1), 2)
    return map(items.__getitem__, starmap(slice, bounds))


def polynomial_from_roots(roots):
    """Return the coefficients of the polynomial that has the given roots.

    The coefficients are listed from the highest power down to the constant
    term, and the leading coefficient is 1.

    >>> roots = [5, -4, 3]            # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x³ - 4 x² - 17 x + 60
    [1, -4, -17, 60]

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # This recipe differs from the one in itertools docs in that it
    # applies list() after each call to convolve().  This avoids
    # hitting stack limits with nested generators.
    coefficients = [1]
    for root in roots:
        # Multiply the running product by the linear factor (x - root).
        linear_factor = (1, -root)
        coefficients = list(convolve(coefficients, linear_factor))
    return coefficients


def iter_index(iterable, value, start=0, stop=None):
    """Yield every index at which *value* occurs in *iterable*, starting at
    index *start* and ending before index *stop*.


    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1))  # start index is inclusive
    [1, 4, 7]
    >>> list(iter_index('AABCADEAF', 'A', 1, 7))  # stop index is not inclusive
    [1, 4]

    The behavior for non-scalar *values* matches the built-in Python types.

    >>> list(iter_index('ABCDABCD', 'AB'))
    [0, 4]
    >>> list(iter_index([0, 1, 2, 3, 0, 1, 2, 3], [0, 1]))
    []
    >>> list(iter_index([[0, 1], [2, 3], [0, 1], [2, 3]], [0, 1]))
    [0, 2]

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    """
    find = getattr(iterable, 'index', None)
    if find is not None:
        # Fast path for sequences: let the object's own ``index`` method do
        # the searching, resuming just past the previous hit each time.
        upper = len(iterable) if stop is None else stop
        position = start - 1
        # ``index`` signals "no further occurrence" with ValueError.
        with suppress(ValueError):
            while True:
                position = find(value, position + 1, upper)
                yield position
        return

    # Slow path for general iterables
    window = islice(iterable, start, stop)
    for position, element in enumerate(window, start):
        if element is value or element == value:
            yield position


def sieve(n):
    """Yield the prime numbers below *n*, in increasing order.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

    """
    # This implementation comes from an older version of the itertools
    # documentation.  The newer implementation is easier to read but is
    # less lazy.
    if n > 2:
        yield 2
    scan_from = 3
    # One flag per integer below n: odd numbers start out as candidates (1),
    # even numbers are never candidates (0).
    flags = bytearray((0, 1)) * (n // 2)
    limit = isqrt(n) + 1
    for p in iter_index(flags, 1, scan_from, stop=limit):
        square = p * p
        # Candidates below p*p that are still flagged are prime.
        yield from iter_index(flags, 1, scan_from, square)
        # Clear the odd multiples of p from p*p upwards.
        flags[square:n:p + p] = bytes(len(range(square, n, p + p)))
        scan_from = square
    yield from iter_index(flags, 1, scan_from)


def _batched(iterable, n, *, strict=False):
    """Batch data into tuples of length *n*. If the number of items in
    *iterable* is not divisible by *n*:
    * The last batch will be shorter if *strict* is ``False``.
    * :exc:`ValueError` will be raised if *strict* is ``True``.

    :exc:`ValueError` is also raised if *n* is less than one.

    >>> list(batched('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)

    def next_batch():
        return tuple(islice(source, n))

    # The two-argument form of iter() calls next_batch() until it returns
    # the empty tuple, i.e. until the source is exhausted.
    for batch in iter(next_batch, ()):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch


# 0x30D00A2 is the ``sys.hexversion`` of Python 3.13.0a2.  From that version
# on, delegate to ``itertools.batched`` (the wrapper passes *strict* through);
# older interpreters use the pure-Python recipe.
if hexversion >= 0x30D00A2:  # pragma: no cover
    from itertools import batched as itertools_batched

    def batched(iterable, n, *, strict=False):
        return itertools_batched(iterable, n, strict=strict)

    # Attach the recipe's docstring to the wrapper, as is done for
    # ``pairwise``.  (In the fallback branch ``batched`` is ``_batched``
    # itself and already carries the docstring.)
    batched.__doc__ = _batched.__doc__

else:
    batched = _batched


def transpose(it):
    """Swap the rows and columns of the input matrix.

    >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
    [(1, 11), (2, 22), (3, 33)]

    The caller should ensure that the dimensions of the input are compatible.
    If the input is empty, no output will be produced.
    """
    # Each output tuple collects one element from every row.
    columns = _zip_strict(*it)
    return columns


def reshape(matrix, cols):
    """Rearrange the 2-D input *matrix* into rows of *cols* elements each.

    The elements are taken in row-major order; the result is an iterator of
    tuples.

    >>> matrix = [(0, 1), (2, 3), (4, 5)]
    >>> cols = 3
    >>> list(reshape(matrix, cols))
    [(0, 1, 2), (3, 4, 5)]
    """
    # Flatten the rows into one stream, then cut it into new rows.
    elements = chain.from_iterable(matrix)
    return batched(elements, cols)


def matmul(m1, m2):
    """Multiply two matrices, returning an iterator over the rows of the
    product.

    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller should ensure that the dimensions of the input matrices are
    compatible with each other.

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # The number of columns of m2 is the row width of the result.
    width = len(m2[0])
    # Pair every row of m1 with every column of m2, in row-major order.
    row_column_pairs = product(m1, transpose(m2))
    entries = starmap(_sumprod, row_column_pairs)
    return batched(entries, width)


def _factor_pollard(n):
    """Return a factor of *n* found with Pollard's rho algorithm.

    Efficient when *n* is odd and composite.  Raises ``ValueError`` when no
    offset below *n* produces a factor.
    """
    for offset in range(1, n):
        # Walk the sequence v -> (v*v + offset) % n at two speeds: ``slow``
        # takes one step per round, ``fast`` takes two.
        slow = fast = 2
        divisor = 1
        while divisor == 1:
            slow = (slow * slow + offset) % n
            fast = (fast * fast + offset) % n
            fast = (fast * fast + offset) % n
            divisor = gcd(slow - fast, n)
        if divisor == n:
            # This offset found nothing useful; try the next one.
            continue
        return divisor
    raise ValueError('prime or under 5')


# Primes below 211, computed once at import time; ``factor`` uses them for
# its trial-division step.
_primes_below_211 = tuple(sieve(211))


def factor(n):
    """Yield the prime factors of n, in non-decreasing order and repeated
    according to their multiplicity.

    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]

    Nothing is yielded when *n* is less than 2.

    Finds small factors with trial division.  Larger factors are
    either verified as prime with ``is_prime`` or split into
    smaller factors with Pollard's rho algorithm.
    """

    # Corner case reduction
    if n < 2:
        return

    # Trial division reduction
    for small_prime in _primes_below_211:
        while n % small_prime == 0:
            yield small_prime
            n //= small_prime

    # Pollard's rho reduction: whatever remains is processed first-in,
    # first-out; composite values are split and both parts queued again.
    large_primes = []
    queue = deque([n] if n > 1 else [])
    while queue:
        candidate = queue.popleft()
        # A value below 211**2 with no prime factor below 211 must be prime.
        if candidate < 211**2 or is_prime(candidate):
            large_primes.append(candidate)
        else:
            divisor = _factor_pollard(candidate)
            queue.extend((divisor, candidate // divisor))
    yield from sorted(large_primes)


def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    *coefficients* lists the coefficients from the highest power down to
    the constant term.  With no coefficients, a zero of the same type as *x*
    is returned.

    Computes with better numeric stability than Horner's method.

    Evaluate ``x^3 - 4 * x^2 - 17 * x + 60`` at ``x = 2.5``:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    degree_plus_one = len(coefficients)
    if not degree_plus_one:
        return type(x)(0)
    # Exponents count down from the highest power to zero.
    exponents = range(degree_plus_one - 1, -1, -1)
    powers = map(pow, repeat(x), exponents)
    return _sumprod(coefficients, powers)


def sum_of_squares(it):
    """Return the sum of the squares of the input values.

    >>> sum_of_squares([10, 20, 30])
    1400

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    # Two independent copies of the input, multiplied element by element.
    left, right = tee(it)
    return _sumprod(left, right)


def polynomial_derivative(coefficients):
    """Return the coefficients of the first derivative of a polynomial.

    *coefficients* lists the coefficients from the highest power down to
    the constant term; the result uses the same ordering.

    Evaluate the derivative of ``x³ - 4 x² - 17 x + 60``:

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]

    Supports all numeric types: int, float, complex, Decimal, Fraction.
    """
    count_of_terms = len(coefficients)
    # Powers n-1 down to 1; the constant term has no partner and is dropped.
    exponents = range(count_of_terms - 1, 0, -1)
    return [coeff * power for coeff, power in zip(coefficients, exponents)]


def totient(n):
    """Return the count of natural numbers up to *n* that are coprime with *n*.

    Euler's totient function φ(n) gives the number of totatives.
    Totatives are integers k in the range 1 ≤ k ≤ n such that gcd(n, k) = 1.

    >>> n = 9
    >>> totient(n)
    6

    >>> totatives = [x for x in range(1, n) if gcd(n, x) == 1]
    >>> totatives
    [1, 2, 4, 5, 7, 8]
    >>> len(totatives)
    6

    Reference:  https://en.wikipedia.org/wiki/Euler%27s_totient_function

    """
    # Only the distinct primes matter; repeated factors are collapsed.
    distinct_primes = set(factor(n))
    result = n
    for p in distinct_primes:
        # Remove the 1/p share of the running total, i.e. scale it by
        # (1 - 1/p) using integer arithmetic.
        result -= result // p
    return result


# Miller–Rabin primality test: https://oeis.org/A014233
#
# Table of fixed base sets consulted by ``is_prime``.  Each entry is a
# ``(limit, bases)`` pair and the entries are listed in order of
# increasing limit.  ``is_prime`` walks the table from the top and uses
# the bases of the first entry whose limit is greater than the number
# being tested; numbers at or above the last limit fall back to randomly
# chosen bases instead.
#
# Note: 18446744073709551616 is 2**64.
_perfect_tests = [
    (2047, (2,)),
    (9080191, (31, 73)),
    (4759123141, (2, 7, 61)),
    (1122004669633, (2, 13, 23, 1662803)),
    (2152302898747, (2, 3, 5, 7, 11)),
    (3474749660383, (2, 3, 5, 7, 11, 13)),
    (18446744073709551616, (2, 325, 9375, 28178, 450775, 9780504, 1795265022)),
    (
        3317044064679887385961981,
        (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41),
    ),
]


@lru_cache
def _shift_to_odd(n):
    'Return s, d such that 2**s * d == n'
    s = ((n - 1) ^ n).bit_length() - 1
    d = n >> s
    assert (1 << s) * d == n and d & 1 and s >= 0
    return s, d


def _strong_probable_prime(n, base):
    """Return ``True`` if odd *n* passes one Miller–Rabin round for *base*.

    *n* must be odd and greater than 2, and *base* must satisfy
    ``2 <= base < n``; both conditions are checked with an assertion.
    """
    assert (n > 2) and (n & 1) and (2 <= base < n)

    minus_one = n - 1

    # Write n - 1 as 2**s * d with d odd.
    s, d = _shift_to_odd(minus_one)

    x = pow(base, d, n)
    if x in (1, minus_one):
        return True

    # Square up to s - 1 more times, looking for -1 modulo n.
    for _ in range(s - 1):
        x = pow(x, 2, n)
        if x == minus_one:
            return True

    return False


# Bound ``randrange`` method of a separate instance of Random() that
# doesn't share state with the default user instance of Random().
# ``is_prime`` draws its random bases from it, so those draws neither
# consume nor depend on the state of the module-level ``random``
# functions.
_private_randrange = random.Random().randrange


def is_prime(n):
    """Return ``True`` if *n* is prime and ``False`` otherwise.

    Basic examples:

        >>> is_prime(37)
        True
        >>> is_prime(3 * 13)
        False
        >>> is_prime(18_446_744_073_709_551_557)
        True

    Find the next prime over one billion:

        >>> next(filter(is_prime, count(10**9)))
        1000000007

    Generate random primes up to 200 bits and up to 60 decimal digits:

        >>> from random import seed, randrange, getrandbits
        >>> seed(18675309)

        >>> next(filter(is_prime, map(getrandbits, repeat(200))))
        893303929355758292373272075469392561129886005037663238028407

        >>> next(filter(is_prime, map(randrange, repeat(10**60))))
        269638077304026462407872868003560484232362454342414618963649

    This function is exact for values of *n* below 10**24.  For larger inputs,
    the probabilistic Miller-Rabin primality test has a less than 1 in 2**128
    chance of a false positive.
    """

    # Everything below 17 is settled by direct lookup.
    if n < 17:
        return n in {2, 3, 5, 7, 11, 13}

    # Cheap rejection of multiples of the primes below 17.
    if not n & 1:
        return False
    if any(not n % small_prime for small_prime in (3, 5, 7, 11, 13)):
        return False

    # Use the fixed bases of the first table entry whose limit exceeds n.
    for limit, fixed_bases in _perfect_tests:
        if n < limit:
            return all(
                _strong_probable_prime(n, base) for base in fixed_bases
            )

    # Beyond the table: 64 rounds with randomly drawn bases, each drawn
    # lazily so that testing stops at the first failing round.
    return all(
        _strong_probable_prime(n, _private_randrange(2, n - 1))
        for _ in range(64)
    )


def loops(n):
    """Return an iterable that yields ``None`` exactly *n* times.

    Intended for loops where the loop variable is ignored.  Like
    ``range(n)`` but doesn't create integers.

    >>> i = 0
    >>> for _ in loops(5):
    ...     i += 1
    >>> i
    5

    """
    placeholder = None
    return repeat(placeholder, n)


def multinomial(*counts):
    """Number of distinct arrangements of a multiset.

    The expression ``multinomial(3, 4, 2)`` has several equivalent
    interpretations:

    * In the expansion of ``(a + b + c)⁹``, the coefficient of the
      ``a³b⁴c²`` term is 1260.

    * There are 1260 distinct ways to arrange 9 balls consisting of 3 reds, 4
      greens, and 2 blues.

    * There are 1260 unique ways to place 9 distinct objects into three bins
      with sizes 3, 4, and 2.

    The :func:`multinomial` function computes the length of
    :func:`distinct_permutations`.  For example, there are 83,160 distinct
    anagrams of the word "abracadabra":

        >>> from more_itertools import distinct_permutations, ilen
        >>> ilen(distinct_permutations('abracadabra'))
        83160

    This can be computed directly from the letter counts, 5a 2b 2r 1c 1d:

        >>> from collections import Counter
        >>> list(Counter('abracadabra').values())
        [5, 2, 2, 1, 1]
        >>> multinomial(5, 2, 2, 1, 1)
        83160

    A binomial coefficient is a special case of multinomial where there are
    only two categories.  For example, the number of ways to arrange 12 balls
    with 5 reds and 7 blues is ``multinomial(5, 7)`` or ``math.comb(12, 5)``.

    When the multiplicities are all just 1, :func:`multinomial`
    is a special case of ``math.factorial`` so that
    ``multinomial(1, 1, 1, 1, 1, 1, 1) == math.factorial(7)``.

    Reference:  https://en.wikipedia.org/wiki/Multinomial_theorem

    """
    # Place one category at a time: with ``running_total`` slots used so
    # far (including this category), there are comb(running_total, size)
    # ways to choose which of them belong to the current category.
    arrangements = 1
    for running_total, size in zip(accumulate(counts), counts):
        arrangements *= comb(running_total, size)
    return arrangements
