import importlib
import json
import os
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional
import base64
import hashlib


def get_home_dir():
    """Return the base directory under which Triton keeps its data.

    The value of the ``TRITON_HOME`` environment variable is returned as-is
    (a ``str``, possibly empty) when the variable is set; otherwise the
    current user's home directory is returned as a ``pathlib.Path``.
    """
    # The fallback is computed up front, exactly as a default argument would be.
    fallback = Path.home()
    return os.environ.get("TRITON_HOME", fallback)


def default_cache_dir():
    """Return the default cache root: ``.triton/cache`` under ``get_home_dir()``."""
    triton_root = os.path.join(get_home_dir(), ".triton")
    return os.path.join(triton_root, "cache")


def default_override_dir():
    """Return the default override root: ``.triton/override`` under ``get_home_dir()``."""
    triton_root = os.path.join(get_home_dir(), ".triton")
    return os.path.join(triton_root, "override")


def default_dump_dir():
    """Return the default dump root: ``.triton/dump`` under ``get_home_dir()``."""
    triton_root = os.path.join(get_home_dir(), ".triton")
    return os.path.join(triton_root, "dump")


class CacheManager(ABC):
    """Abstract interface shared by the cache managers in this module.

    Subclasses must implement all four abstract methods before they can be
    instantiated. The base implementations do nothing and return ``None``.
    """

    def __init__(self, key):
        """Accept the cache ``key``; the base class itself stores nothing."""

    @abstractmethod
    def get_file(self, filename) -> Optional[str]:
        """Look up ``filename``.

        The concrete managers in this file return a local path to the entry,
        or ``None`` when it is not cached.
        """

    @abstractmethod
    def put(self, data, filename, binary=True) -> str:
        """Store ``data`` under ``filename``.

        The concrete managers in this file return the local path of the
        stored entry.
        """

    @abstractmethod
    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Look up the group recorded for ``filename``.

        The concrete managers in this file return a mapping from child name
        to local path, or ``None`` when no usable group record exists.
        """

    @abstractmethod
    def put_group(self, filename: str, group: Dict[str, str]):
        """Record ``group`` (child name -> path) as the group of ``filename``."""


class FileCacheManager(CacheManager):
    """Cache manager that keeps each entry as a file in a per-key directory.

    The directory is ``<root>/<key>``. ``<root>`` is chosen by mode, with
    ``dump`` taking precedence over ``override``:

    * ``dump=True``: ``TRITON_DUMP_DIR`` or ``default_dump_dir()``;
    * ``override=True``: ``TRITON_OVERRIDE_DIR`` or ``default_override_dir()``;
    * otherwise: ``TRITON_CACHE_DIR`` or ``default_cache_dir()``.

    An environment variable that is unset, empty or whitespace-only falls
    back to the default. The directory is created in dump and default mode.
    In override mode it is neither created nor given a ``lock_path``, so
    ``put`` is rejected there.
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve (and, except in override mode, create) the directory for ``key``."""
        self.key = key
        # Only set for the writable modes; `put` asserts that it is present.
        self.lock_path = None
        if dump:
            self.cache_dir = os.getenv("TRITON_DUMP_DIR", "").strip() or default_dump_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = os.getenv("TRITON_OVERRIDE_DIR", "").strip() or default_override_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            if self.cache_dir:
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")

    def _make_path(self, filename) -> str:
        """Return the path of ``filename`` inside this manager's directory."""
        return os.path.join(self.cache_dir, filename)

    def has_file(self, filename) -> bool:
        """Return whether ``filename`` currently exists in the cache directory.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        return os.path.exists(self._make_path(filename))

    def get_file(self, filename) -> Optional[str]:
        """Return the path of the cached ``filename``, or ``None`` if absent."""
        if self.has_file(filename):
            return self._make_path(filename)
        return None

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return the still-existing members of the group recorded for ``filename``.

        The group record is the JSON file ``__grp__<filename>``, expected to
        hold ``{"child_paths": {child name: path}}``.

        Returns:
            A dict with the children whose paths exist on disk, or ``None``
            when the record is missing, cannot be parsed, or does not have
            the expected shape. A ``child_paths`` list, as written by
            ``RemoteCacheManager.put_group``, is one such unexpected shape.
        """
        grp_filename = f"__grp__{filename}"
        if not self.has_file(grp_filename):
            return None
        try:
            with open(self._make_path(grp_filename)) as f:
                grp_data = json.load(f)
        except (FileNotFoundError, ValueError):
            # The record vanished after the existence check, or it is not
            # valid JSON (json.JSONDecodeError is a ValueError): report a
            # miss so the caller can regenerate the group.
            return None
        # Invalid group data.
        if not isinstance(grp_data, dict):
            return None
        child_paths = grp_data.get("child_paths", None)
        if not isinstance(child_paths, dict):
            return None
        return {child: path for child, path in child_paths.items() if os.path.exists(path)}

    # Note a group of pushed files as being part of a group
    def put_group(self, filename: str, group: Dict[str, str]) -> str:
        """Write the group record ``__grp__<filename>`` and return its path.

        Raises:
            RuntimeError: if no cache directory is configured.
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        grp_contents = json.dumps({"child_paths": group})
        grp_filename = f"__grp__{filename}"
        return self.put(grp_contents, grp_filename, binary=False)

    @staticmethod
    def _discard_temp(temp_path, temp_dir):
        """Best-effort removal of a failed write's temp file and temp directory."""
        for remove, target in ((os.remove, temp_path), (os.rmdir, temp_dir)):
            try:
                remove(target)
            except OSError:
                # Nothing more can be done; the caller re-raises the real error.
                pass

    def put(self, data, filename, binary=True) -> str:
        """Store ``data`` as ``filename`` in the cache directory and return its path.

        The ``binary`` argument is accepted for interface compatibility but
        not consulted: ``bytes`` are written in binary mode, anything else is
        converted with ``str()`` and written in text mode.

        Raises:
            RuntimeError: if no cache directory is configured.
            AssertionError: if the manager has no ``lock_path`` (override mode).
        """
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")
        binary = isinstance(data, bytes)
        if not binary:
            data = str(data)
        assert self.lock_path is not None
        filepath = self._make_path(filename)
        # Random ID to avoid any collisions
        rnd_id = str(uuid.uuid4())
        # we use the PID in case a bunch of these around so we can see what PID made it
        pid = os.getpid()
        # use temp dir to be robust against program interruptions
        temp_dir = os.path.join(self.cache_dir, f"tmp.pid_{pid}_{rnd_id}")
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, filename)

        mode = "wb" if binary else "w"
        try:
            with open(temp_path, mode) as f:
                f.write(data)
            # Replace is guaranteed to be atomic on POSIX systems if it succeeds
            # so filepath cannot see a partial write
            os.replace(temp_path, filepath)
        except BaseException:
            # Do not leave a stray tmp.pid_* directory behind when the write fails.
            self._discard_temp(temp_path, temp_dir)
            raise
        # Remove only the now-empty temp directory; parent directories are left alone.
        os.rmdir(temp_dir)
        return filepath


class RemoteCacheBackend:
    """
    Base class for backends that talk to a remote/distributed cache.

    The class does not derive from ``ABC``, so the ``@abstractmethod``
    markers below are informational only: instances can still be created and
    the base methods simply return ``None``.
    """

    def __init__(self, key: str):
        """Accept the cache ``key``; the base class itself stores nothing."""

    @abstractmethod
    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch the entries named in ``filenames``, keyed by filename."""

    @abstractmethod
    def put(self, filename: str, data: bytes):
        """Store ``data`` under ``filename``."""


class RedisRemoteCacheBackend(RemoteCacheBackend):
    """Remote cache backend that stores entries in Redis.

    Configuration comes from the environment:

    * ``TRITON_REDIS_HOST`` (default ``"localhost"``);
    * ``TRITON_REDIS_PORT`` (default ``6379``);
    * ``TRITON_REDIS_KEY_FORMAT`` (default ``"triton:{key}:{filename}"``),
      a ``str.format`` template that receives ``key`` and ``filename``.
    """

    def __init__(self, key):
        """Create the Redis client for cache ``key``."""
        # Imported here so this module can be loaded without the redis package.
        import redis
        self._key = key
        self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
        self._redis = redis.Redis(
            host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
            port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
        )

    def _get_key(self, filename: str) -> str:
        """Return the Redis key under which ``filename`` is stored."""
        return self._key_fmt.format(key=self._key, filename=filename)

    def get(self, filenames: List[str]) -> Dict[str, bytes]:
        """Fetch several entries with a single MGET.

        Returns:
            A dict from filename to stored value containing only the entries
            that exist. The annotation follows the ``RemoteCacheBackend``
            contract; values are whatever the Redis client returns
            (presumably ``bytes`` with the client configured above).
        """
        if not filenames:
            # Nothing to look up: skip the round trip (MGET needs at least one key).
            return {}
        results = self._redis.mget([self._get_key(f) for f in filenames])
        return {filename: result for filename, result in zip(filenames, results) if result is not None}

    def put(self, filename: str, data: bytes) -> None:
        """Store ``data`` under the Redis key derived from ``filename``."""
        self._redis.set(self._get_key(filename), data)


class RemoteCacheManager(CacheManager):
    """Cache manager backed by a remote store, with files materialized locally.

    The backend class is named by ``TRITON_REMOTE_CACHE_BACKEND`` in the form
    ``<module path>:<class name>`` and is instantiated with the cache key.
    Every entry obtained from or sent to the backend is also written through
    a ``FileCacheManager`` so callers receive a local path. In dump and
    override mode the backend is bypassed and all calls go straight to the
    ``FileCacheManager``.
    """

    def __init__(self, key, override=False, dump=False):
        """Load the configured backend and set up the local file manager."""
        # Set up the backend pointed to by `TRITON_REMOTE_CACHE_BACKEND`.
        backend_spec = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
        module_path, clz_nme = backend_spec.split(":")
        backend_cls = getattr(importlib.import_module(module_path), clz_nme)
        self._backend = backend_cls(key)

        self._override = override
        self._dump = dump

        # Remote entries are materialized locally through a `FileCacheManager`.
        self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

    def _bypasses_backend(self):
        """Return a truthy value in dump/override mode, which the backend does not handle."""
        return self._dump or self._override

    def _materialize(self, filename: str, data: bytes):
        """Write ``data`` to the local file cache and return the resulting path."""
        return self._file_cache_manager.put(data, filename, binary=True)

    def get_file(self, filename: str) -> Optional[str]:
        """Return a local path holding ``filename``, or ``None`` on a miss."""
        if self._bypasses_backend():
            return self._file_cache_manager.get_file(filename)

        # The remote backend is consulted on every lookup -- even when the
        # local file cache already has the item -- so that LRU accounting
        # works as expected.
        fetched = self._backend.get([filename])
        if len(fetched) == 0:
            return None
        ((_, payload),) = fetched.items()
        return self._materialize(filename, payload)

    def put(self, data, filename: str, binary=True) -> str:
        """Send ``data`` to the backend and return the local path of its copy.

        Data that is not ``bytes`` is converted with ``str()`` and encoded as
        UTF-8 before it is stored.
        """
        if self._bypasses_backend():
            return self._file_cache_manager.put(data, filename, binary=binary)

        payload = data if isinstance(data, bytes) else str(data).encode("utf-8")
        self._backend.put(filename, payload)
        return self._materialize(filename, payload)

    def get_group(self, filename: str) -> Optional[Dict[str, str]]:
        """Return local paths for the group members the backend still holds.

        Returns ``None`` when the group record is missing or has no
        ``child_paths`` entry.
        """
        if self._bypasses_backend():
            return self._file_cache_manager.get_group(filename)

        record_path = self.get_file(f"__grp__{filename}")
        if record_path is None:
            return None
        with open(record_path) as f:
            record = json.load(f)

        child_paths = record.get("child_paths", None)
        if child_paths is None:
            return None

        # Found group data: fetch every child in one call and materialize each hit.
        return {
            child: self._materialize(child, payload)
            for child, payload in self._backend.get(child_paths).items()
        }

    def put_group(self, filename: str, group: Dict[str, str]):
        """Store the group record for ``filename``.

        Only the sorted child names are recorded, not their local paths.
        """
        if self._bypasses_backend():
            return self._file_cache_manager.put_group(filename, group)

        child_names = sorted(group.keys())
        record = json.dumps({"child_paths": child_names})
        return self.put(record, f"__grp__{filename}")


# Cache manager class instantiated by the get_*_manager factories below. It
# starts as FileCacheManager and is replaced by get_cache_manager() when
# TRITON_CACHE_MANAGER names a different class.
__cache_cls = FileCacheManager
# The TRITON_CACHE_MANAGER value that `__cache_cls` was loaded from; "DEFAULT"
# until a user-specified class has been loaded.
__cache_cls_nme = "DEFAULT"


def _base32(key):
    """Re-encode the hex string ``key`` as base32 text without ``=`` padding."""
    # Assume key is a hex string.
    raw = bytes.fromhex(key)
    encoded = base64.b32encode(raw).decode("utf-8")
    return encoded.rstrip("=")


def get_cache_manager(key) -> CacheManager:
    """Return a cache manager for the hex string ``key``.

    When ``TRITON_CACHE_MANAGER`` is set to ``<module path>:<class name>``
    and differs from the value the current class was loaded from, that class
    is imported and becomes the module-wide manager class. Unsetting the
    variable later does not restore the previous class. The manager is
    constructed with the base32 form of ``key``.
    """
    global __cache_cls
    global __cache_cls_nme

    requested = os.environ.get("TRITON_CACHE_MANAGER", None)
    needs_reload = requested is not None and requested != __cache_cls_nme
    if needs_reload:
        module_path, clz_nme = requested.split(":")
        __cache_cls = getattr(importlib.import_module(module_path), clz_nme)
        __cache_cls_nme = requested

    return __cache_cls(_base32(key))


def get_override_manager(key) -> CacheManager:
    """Return an override-mode manager for the hex string ``key``.

    Uses whichever manager class is currently selected; it does not consult
    ``TRITON_CACHE_MANAGER`` itself.
    """
    encoded_key = _base32(key)
    return __cache_cls(encoded_key, override=True)


def get_dump_manager(key) -> CacheManager:
    """Return a dump-mode manager for the hex string ``key``.

    Uses whichever manager class is currently selected; it does not consult
    ``TRITON_CACHE_MANAGER`` itself.
    """
    encoded_key = _base32(key)
    return __cache_cls(encoded_key, dump=True)


def make_so_cache_key(version_hash, signature, constants, ids, **kwargs):
    """Build the cache key that identifies one compiled artifact.

    Signature entries starting with ``*`` are collapsed to ``ptr``. The
    version hash, the concatenated signature entries, ``constants``, ``ids``
    and each keyword value (in the order given; names are not included) are
    joined with ``-``, hashed with SHA-256, and the hex digest is returned in
    base32 form.
    """
    # Get unique key for the compiled code
    type_tags = []
    for ty in signature.values():
        type_tags.append('ptr' if ty[0] == '*' else ty)
    parts = [f"{version_hash}-{''.join(type_tags)}-{constants}-{ids}"]
    parts.extend(f"{value}" for value in kwargs.values())
    digest = hashlib.sha256("-".join(parts).encode("utf-8")).hexdigest()
    return _base32(digest)
