"""
Thin wrappers around `itertools`.
"""
import itertools

from ..auto import tqdm as tqdm_auto

__author__ = {"github.com/": ["casperdcl"]}
__all__ = [
    'chain', 'product', 'permutations', 'combinations', 'combinations_with_replacement', 'batched']


def chain(*iterables, total=None, tqdm_class=tqdm_auto, **kwargs):
    """Equivalent of `itertools.chain`."""
    if total is None:
        try:
            total = sum(map(len, iterables))
        except (TypeError, AttributeError):
            pass
    return tqdm_class(itertools.chain(*iterables), total=total, **kwargs)


def product(*iterables, repeat=1, total=None, tqdm_class=tqdm_auto, **kwargs):
    """Equivalent of `itertools.product`."""
    if total is None:
        try:
            lens = list(map(len, iterables))
        except (TypeError, AttributeError):
            pass
        else:  # py>=3.8: math.prod(lens) ** repeat
            total = 1
            for i in lens:
                total *= i
            total **= repeat
    for i in tqdm_class(itertools.product(*iterables, repeat=repeat), total=total, **kwargs):
        yield i


def permutations(iterable, r=None, total=None, tqdm_class=tqdm_auto, **kwargs):
    """Equivalent of `itertools.permutations`."""
    if total is None:
        try:
            n = len(iterable)
        except (TypeError, AttributeError):
            pass
        else:
            r = n if r is None else r
            if r > n:
                total = 0
            else:  # py>=3.8: math.perm(n, r)
                total = 1
                for i in range(n, n-r, -1):
                    total *= i
    return tqdm_class(itertools.permutations(iterable, r), total=total, **kwargs)


def combinations(iterable, r, total=None, tqdm_class=tqdm_auto, **kwargs):
    """Equivalent of `itertools.combinations`."""
    if total is None:
        try:
            n = len(iterable)
        except (TypeError, AttributeError):
            pass
        else:
            if r > n:
                total = 0
            else:  # py>=3.8: math.comb(n, r)
                total = 1
                for i in range(n, n-r, -1):
                    total *= i
                for i in range(1, r+1):
                    total //= i
    return tqdm_class(itertools.combinations(iterable, r), total=total, **kwargs)


def combinations_with_replacement(iterable, r, total=None, tqdm_class=tqdm_auto, **kwargs):
    """Equivalent of `itertools.combinations_with_replacement`."""
    if total is None:
        try:
            n = len(iterable)
        except (TypeError, AttributeError):
            pass
        else:
            total = 1
            for i in range(n+r-1, n-1, -1):
                total *= i
            for i in range(1, r+1):
                total //= i
    return tqdm_class(itertools.combinations_with_replacement(iterable, r), total=total, **kwargs)


def batched(iterable, n, total=None, tqdm_class=tqdm_auto, **kwargs):
    """Equivalent of `itertools.batched`."""
    if total is None:
        try:
            total = len(iterable)
        except (TypeError, AttributeError):
            pass
    return tqdm_class(itertools.batched(iterable, n), unit_scale=n,
                      total=(total+n-1) // n if total is not None else None,
                      **kwargs)
