import binascii
import json

try:
    from collections.abc import Iterable, Mapping
except ImportError:
    from collections import Mapping, Iterable

from jose import jwk
from jose.backends.base import Key
from jose.constants import ALGORITHMS
from jose.exceptions import JWSError, JWSSignatureError
from jose.utils import base64url_decode, base64url_encode


def sign(payload, key, headers=None, algorithm=ALGORITHMS.HS256):
    """Sign a payload and return its compact JWS serialization.

    Args:
        payload (str or dict): The data to sign. Mappings are serialized to
            compact JSON before being encoded.
        key (str or dict): The signing key; may be a single JWK or a JWK set.
        headers (dict, optional): Extra header fields merged on top of the
            default ``typ``/``alg`` header. A field given here replaces the
            default field of the same name.
        algorithm (str, optional): The algorithm used to sign the payload.
            Defaults to HS256.

    Returns:
        str: The ``header.payload.signature`` string, each segment
        base64url-encoded.

    Raises:
        JWSError: If the algorithm is not supported, or if building the key
            or computing the signature fails.

    Examples:

        >>> jws.sign({'a': 'b'}, 'secret', algorithm='HS256')
        'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'

    """
    if algorithm in ALGORITHMS.SUPPORTED:
        header_b64 = _encode_header(algorithm, additional_headers=headers)
        payload_b64 = _encode_payload(payload)
        return _sign_header_and_claims(header_b64, payload_b64, algorithm, key)

    raise JWSError("Algorithm %s not supported." % algorithm)


def verify(token, key, algorithms, verify=True):
    """Check a JWS signature and return the token's payload.

    Args:
        token (str): The compact JWS to check.
        key (str or dict): Key material to verify against; may be a single
            JWK or a JWK set.
        algorithms (str or list): The algorithms accepted for this token; the
            ``alg`` named in the token header must be one of them.
        verify (bool, optional): When False the token is only decoded and no
            signature check is performed. Defaults to True.

    Returns:
        The decoded payload segment (the output of ``base64url_decode`` on
        the middle segment of the token).

    Raises:
        JWSError: If the token cannot be decoded or its signature does not
            verify.

    Examples:

        >>> token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhIjoiYiJ9.jiMyrsmD8AoHWeQgmxZ5yq8z0lXS67_QGs52AzC8Ru8'
        >>> jws.verify(token, 'secret', algorithms='HS256')

    """
    header, payload, signing_input, signature = _load(token)

    if not verify:
        # Caller explicitly opted out of signature checking.
        return payload

    _verify_signature(signing_input, header, signature, key, algorithms)
    return payload


def get_unverified_header(token):
    """Decode and return the token header without checking the signature.

    Args:
        token (str): A compact JWS whose header should be decoded.

    Returns:
        dict: The parsed JSON header of the token.

    Raises:
        JWSError: If the token cannot be decoded.
    """
    # _load returns (header, payload, signing_input, signature).
    return _load(token)[0]


def get_unverified_headers(token):
    """Decode and return the token header without checking the signature.

    Kept only as a backwards-compatible alias; it delegates to
    get_unverified_header() and returns exactly what that returns.

    Args:
        token (str): A compact JWS whose header should be decoded.

    Returns:
        dict: The parsed JSON header of the token.

    Raises:
        JWSError: If the token cannot be decoded.
    """
    decoded_header = get_unverified_header(token)
    return decoded_header


def get_unverified_claims(token):
    """Decode and return the token payload without checking the signature.

    Args:
        token (str): A compact JWS whose payload (claims) should be decoded.

    Returns:
        The decoded payload segment, exactly as produced by
        ``base64url_decode``; it is not JSON-parsed here.

    Raises:
        JWSError: If the token cannot be decoded.
    """
    # _load returns (header, payload, signing_input, signature).
    return _load(token)[1]


def _encode_header(algorithm, additional_headers=None):
    """Build the JWS header and return it base64url-encoded.

    The header starts as ``{"typ": "JWT", "alg": algorithm}``; any entries in
    ``additional_headers`` override those defaults. Keys are sorted and
    separators are compact so the same input always yields the same bytes.
    """
    fields = dict(typ="JWT", alg=algorithm)
    # A falsy value (None, {}) leaves the defaults untouched.
    fields.update(additional_headers or {})

    serialized = json.dumps(fields, sort_keys=True, separators=(",", ":"))
    return base64url_encode(serialized.encode("utf-8"))


def _encode_payload(payload):
    """Return the base64url-encoded payload segment.

    Mappings are serialized to compact JSON and text is encoded as UTF-8
    before base64url encoding. Any other value (e.g. bytes) is handed to
    ``base64url_encode`` unchanged.

    Args:
        payload (str, bytes or Mapping): The data to place in the token.

    Returns:
        The base64url-encoded payload, as returned by ``base64url_encode``.

    Raises:
        JWSError: If a mapping payload cannot be serialized to JSON
            (``json.dumps`` raised ValueError, e.g. a circular reference).
    """
    if isinstance(payload, Mapping):
        try:
            payload = json.dumps(
                payload,
                separators=(",", ":"),
            )
        except ValueError as e:
            # Previously swallowed, which left a Mapping to fail obscurely
            # inside base64url_encode; report it explicitly instead.
            raise JWSError("Invalid payload: %s" % e)

    # Text (including the JSON produced above) must become bytes before
    # base64url encoding; the docstring of sign() advertises str payloads.
    if isinstance(payload, str):
        payload = payload.encode("utf-8")

    return base64url_encode(payload)


def _sign_header_and_claims(encoded_header, encoded_claims, algorithm, key):
    """Sign ``header.claims`` and return the complete compact token as text.

    Args:
        encoded_header (bytes): The base64url-encoded header segment.
        encoded_claims (bytes): The base64url-encoded payload segment.
        algorithm (str): Algorithm passed to ``jwk.construct`` when ``key``
            is not already a ``Key`` instance.
        key: A ``Key`` instance or raw key material for ``jwk.construct``.

    Returns:
        str: ``header.claims.signature`` decoded as UTF-8.

    Raises:
        JWSError: Wrapping any exception raised while constructing the key
            or computing the signature.
    """
    signing_input = b".".join((encoded_header, encoded_claims))

    try:
        signer = key if isinstance(key, Key) else jwk.construct(key, algorithm)
        raw_signature = signer.sign(signing_input)
    except Exception as e:
        raise JWSError(e)

    segments = (encoded_header, encoded_claims, base64url_encode(raw_signature))
    return b".".join(segments).decode("utf-8")


def _load(jwt):
    """Split a compact JWS into its decoded parts without verifying it.

    Args:
        jwt (str or bytes): The compact serialization
            ``header.payload.signature``; text is encoded as UTF-8 first.

    Returns:
        tuple: ``(header, payload, signing_input, signature)`` where
        ``header`` is the parsed JSON header mapping, ``payload`` and
        ``signature`` are the base64url-decoded segments, and
        ``signing_input`` is the raw ``header.payload`` bytes that the
        signature covers.

    Raises:
        JWSError: If the token has fewer than three segments, a segment is
            not decodable, or the header is not a JSON object.
    """
    if isinstance(jwt, str):
        jwt = jwt.encode("utf-8")

    try:
        signing_input, crypto_segment = jwt.rsplit(b".", 1)
        header_segment, claims_segment = signing_input.split(b".", 1)
    except ValueError:
        raise JWSError("Not enough segments")

    # Decoded in its own try: binascii.Error is a subclass of ValueError, so
    # sharing the split's `except ValueError` would mislabel invalid header
    # base64 as "Not enough segments".
    try:
        header_data = base64url_decode(header_segment)
    except (TypeError, binascii.Error):
        raise JWSError("Invalid header padding")

    try:
        header = json.loads(header_data.decode("utf-8"))
    except ValueError as e:
        # Also covers UnicodeDecodeError, a ValueError subclass.
        raise JWSError("Invalid header string: %s" % e)

    if not isinstance(header, Mapping):
        raise JWSError("Invalid header string: must be a json object")

    try:
        payload = base64url_decode(claims_segment)
    except (TypeError, binascii.Error):
        raise JWSError("Invalid payload padding")

    try:
        signature = base64url_decode(crypto_segment)
    except (TypeError, binascii.Error):
        raise JWSError("Invalid crypto padding")

    return (header, payload, signing_input, signature)


def _sig_matches_keys(keys, signing_input, signature, alg):
    """Return True if any key in ``keys`` verifies ``signature``.

    Entries that are not ``Key`` instances are built with
    ``jwk.construct(entry, alg)``; errors from that construction propagate to
    the caller. Errors raised by ``verify`` itself are treated as "no match"
    so the next candidate is tried.
    """
    for candidate in keys:
        verifier = candidate if isinstance(candidate, Key) else jwk.construct(candidate, alg)
        try:
            if verifier.verify(signing_input, signature):
                return True
        except Exception:
            continue
    return False


def _get_keys(key):
    """Normalize ``key`` into an iterable of candidate verification keys.

    Accepts a ``Key`` instance, a JWK or JWK set (as a mapping or as JSON
    text), any other mapping (its values are used), a list/tuple of keys, or
    a single scalar key such as an HMAC secret.

    Returns:
        An iterable of candidate keys to try in order.
    """
    if isinstance(key, Key):
        return (key,)

    # Key material may arrive as JSON text (a JWK, JWK set, or list of keys).
    # Numbers are parsed as strings so numeric-looking text is preserved.
    try:
        parsed = json.loads(key, parse_int=str, parse_float=str)
    except Exception:
        # Best effort: anything that is not JSON text is used as given.
        parsed = None

    # Only containers replace the original value. A secret that merely
    # happens to be a JSON scalar ('"abc"', 'true', 'null', 'NaN') must be
    # used verbatim; _sign_header_and_claims passes the raw key to
    # jwk.construct, so a parsed scalar would not match what was signed.
    if isinstance(parsed, (dict, list)):
        key = parsed

    if isinstance(key, Mapping):
        if "keys" in key:
            # JWK Set per RFC 7517
            return key["keys"]
        if "kty" in key:
            # Individual JWK per RFC 7517
            return (key,)
        # Some other mapping. Firebase uses just dict of kid, cert pairs
        values = key.values()
        if values:
            return values
        return (key,)

    # Iterable but not text or mapping => list- or tuple-like
    if isinstance(key, Iterable) and not isinstance(key, (str, bytes)):
        return key

    # Scalar value, wrap in tuple.
    return (key,)


def _verify_signature(signing_input, header, signature, key="", algorithms=None):
    """Check that ``signature`` over ``signing_input`` verifies with ``key``.

    Args:
        signing_input (bytes): The ``header.payload`` bytes that were signed.
        header (Mapping): The decoded JWS header; its ``alg`` selects the
            verification algorithm.
        signature (bytes): The decoded signature segment.
        key: Key material in any form accepted by ``_get_keys``.
        algorithms (str or iterable, optional): The allowed algorithm
            names. A single string is treated as one name. ``None`` disables
            the allow-list check.

    Raises:
        JWSError: If the header has no ``alg``, the ``alg`` is not allowed,
            or no candidate key verifies the signature.
    """
    alg = header.get("alg")
    if not alg:
        raise JWSError("No algorithm was specified in the JWS header.")

    # A bare string must be one algorithm name. Otherwise `alg in "HS256"`
    # is a substring test, and it raises TypeError when the (untrusted)
    # header alg is not a string.
    if isinstance(algorithms, str):
        algorithms = [algorithms]

    if algorithms is not None and alg not in algorithms:
        raise JWSError("The specified alg value is not allowed")

    keys = _get_keys(key)
    try:
        if not _sig_matches_keys(keys, signing_input, signature, alg):
            raise JWSSignatureError()
    except JWSSignatureError:
        raise JWSError("Signature verification failed.")
    except JWSError:
        raise JWSError("Invalid or unsupported algorithm: %s" % alg)
