# mypy: allow-untyped-defs
from typing import Any, Optional

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.dataframe.structures import DataChunkDF
from torch.utils.data.datapipes.datapipe import DFIterDataPipe, IterDataPipe


# TODO(VitalyFedyunin): Add error when two different traces get combined

# Names exported by a star-import of this module.
__all__ = [
    "Capture",
    "CaptureA",
    "CaptureAdd",
    "CaptureCall",
    "CaptureControl",
    "CaptureDataFrame",
    "CaptureDataFrameWithDataPipeOps",
    "CaptureF",
    "CaptureGetAttr",
    "CaptureGetItem",
    "CaptureInitial",
    "CaptureLikeMock",
    "CaptureMul",
    "CaptureSetItem",
    "CaptureSub",
    "CaptureVariable",
    "CaptureVariableAssign",
    "DataFrameTracer",
    "DataFrameTracedOps",
    "disable_capture",
    "get_val",
]


def disable_capture():
    """Turn capturing off globally.

    Sets ``CaptureControl.disabled``; while it is set, constructing a
    ``CaptureVariable`` raises ``RuntimeError``.
    """
    CaptureControl.disabled = True


class CaptureControl:
    """Holder of the process-wide switch consulted when capture variables are created."""

    # Flipped to True by disable_capture(); checked in CaptureVariable.__init__.
    disabled: bool = False


class DataFrameTracedOps(DFIterDataPipe):
    """Datapipe that replays captured operations on every item of a source datapipe.

    Args:
        source_datapipe: iterable supplying the items to process.
        output_var: object exposing ``apply_ops(item)``; its return value is
            what this datapipe yields for each source item.
    """

    def __init__(self, source_datapipe, output_var):
        self.output_var = output_var
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for dataframe in self.source_datapipe:
            traced = self.output_var.apply_ops(dataframe)
            yield traced


# TODO(VitalyFedyunin): Extract this list from the DFIterDataPipe registered functions
# Attribute names that CaptureDataFrameWithDataPipeOps.__getattr__ forwards to the
# datapipe returned by as_datapipe() instead of recording them as captured operations.
DATAPIPES_OPS = [
    "_dataframes_as_tuples",
    "groupby",
    "_dataframes_filter",
    "map",
    "to_datapipe",
    "shuffle",
    "concat",
    "batch",
    "_dataframes_per_row",
    "_dataframes_concat",
    "_dataframes_shuffle",
]

# Attribute names for which CaptureDataFrameWithDataPipeOps.__getattr__ raises
# AttributeError rather than creating a capture node.
UNIMPLEMENTED_ATTR = ["__deepcopy__", "__setstate__", "is_shardable", "apply_sharding"]


class Capture:
    """Base node of a lazily recorded ("captured") chain of operations.

    Attribute access, indexing, arithmetic and calls on a ``Capture`` do not
    act immediately; they create further capture nodes. Nodes created from one
    another share a single ``ctx`` dict with the keys:

    * ``"operations"`` - recorded statements, executed in order on replay
    * ``"variables"``  - the ``CaptureVariable`` objects created so far
    * ``"schema_df"``  - the object passed to ``__init__`` as ``schema_df``
    """

    # TODO: All operations are shared across entire InitialCapture, need to figure out what if we join two captures

    def __init__(self, schema_df=None):
        self.ctx = {"operations": [], "variables": [], "schema_df": schema_df}

    def __str__(self):
        return self._ops_str()

    def _ops_str(self):
        """Return the recorded operations rendered one per line."""
        return "\n".join(str(op) for op in self.ctx["operations"])

    def __getstate__(self):
        """Return a shallow copy of the instance dict for pickling.

        NOTE: the schema and every variable's computed value are cleared on the
        live object (not on a copy) before the state is taken.
        """
        # TODO(VitalyFedyunin): Currently can't pickle (why?)
        self.ctx["schema_df"] = None
        for var in self.ctx["variables"]:
            var.calculated_value = None
        return dict(self.__dict__)

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getattr__(self, attrname):
        """Record access to an attribute that does not exist on the node itself.

        Raises:
            RuntimeError: for ``kwarg`` / ``kwargs``.
            AttributeError: for ``__deepcopy__``, and for ``ctx`` when the
                instance has no context yet.
        """
        if attrname == "kwarg" or attrname == "kwargs":
            raise RuntimeError("no kwargs!")
        # "ctx" only reaches here when it has not been set yet (e.g. the
        # instance was made with __new__). Without this guard the self.ctx
        # lookup below would re-enter __getattr__ until RecursionError.
        if attrname in ("__deepcopy__", "ctx"):
            raise AttributeError(attrname)
        ctx = self.ctx
        return CaptureGetAttr(self, attrname, ctx=ctx)

    def __getitem__(self, key):
        return CaptureGetItem(self, key, ctx=self.ctx)

    def __setitem__(self, key, value):
        self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))

    def _capture_binary_op(self, op_cls, other):
        """Record ``op_cls(self, other)`` as an assignment to a fresh variable and return it."""
        res = op_cls(self, other, ctx=self.ctx)
        var = CaptureVariable(res, ctx=self.ctx)
        self.ctx["operations"].append(
            CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
        )
        return var

    def __add__(self, add_val):
        return self._capture_binary_op(CaptureAdd, add_val)

    def __sub__(self, add_val):
        return self._capture_binary_op(CaptureSub, add_val)

    def __mul__(self, add_val):
        return self._capture_binary_op(CaptureMul, add_val)

    def _is_context_empty(self):
        """Return True when neither operations nor variables have been recorded."""
        return not self.ctx["operations"] and not self.ctx["variables"]

    def apply_ops_2(self, dataframe):
        """Bind ``dataframe`` to the first variable and execute every recorded operation."""
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        self.ctx["variables"][0].calculated_value = dataframe
        for op in self.ctx["operations"]:
            op.execute()

    @property
    def columns(self):
        """Replay the operations on ``ctx["schema_df"]`` and return the result's ``columns``."""
        self.apply_ops_2(self.ctx["schema_df"])
        value = self.execute()
        return value.columns

    # TODO(VitalyFedyunin): Add tests
    # TODO(VitalyFedyunin): Need to join context if one of them are empty because we used capture

    def __call__(self, *args, **kwargs):
        """Record a call of this node and return the variable holding its result.

        If this node's context is still empty, it first adopts the context of
        the first argument (positional, then keyword values) that is a
        ``Capture`` with a non-empty context.
        """
        # TODO: Check if args or kwargs have more than one different context
        if self._is_context_empty():
            # TODO: Allow CaptureA to take context from mock
            # Keyword names are always str, so only the values can be captures.
            for candidate in (*args, *kwargs.values()):
                if isinstance(candidate, Capture) and not candidate._is_context_empty():
                    self.ctx = candidate.ctx
                    break

        res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
        var = CaptureVariable(None, ctx=self.ctx)
        t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
        self.ctx["operations"].append(t)
        return var


class CaptureF(Capture):
    """Capture node whose payload is an arbitrary set of keyword arguments.

    Args:
        ctx: shared context dict; a fresh one with empty ``"operations"`` and
            ``"variables"`` lists is created when omitted.
        **kwargs: stored unchanged on ``self.kwargs`` for subclasses to read.
    """

    def __init__(self, ctx=None, **kwargs):
        self.ctx = {"operations": [], "variables": []} if ctx is None else ctx
        self.kwargs = kwargs


class CaptureA(CaptureF):
    """Capture standing in for a real attribute.

    Reads ``kwargs["name"]`` for display and ``kwargs["real_attribute"]`` as
    the value produced on execution.
    """

    def __str__(self):
        return "{}".format(self.kwargs["name"])

    def execute(self):
        """Return the wrapped real attribute."""
        return self.kwargs["real_attribute"]


class CaptureLikeMock:
    """Context manager that temporarily replaces a dotted-path attribute with a ``CaptureA``.

    Args:
        name: dotted path of the attribute to replace; it is split into a
            target getter and an attribute name by ``unittest.mock._get_target``.
    """

    def __init__(self, name):
        from unittest import mock

        # TODO(VitalyFedyunin): Do not use private function here, copy own implementation instead.
        self.get_target, self.attribute = mock._get_target(name)  # type: ignore[attr-defined]
        self.name = name

    def __enter__(self):
        # Remember the real attribute so __exit__ can put it back.
        self.save = getattr(self.get_target(), self.attribute)
        replacement = CaptureA(name=self.name, real_attribute=self.save)
        setattr(self.get_target(), self.attribute, replacement)

    def __exit__(self, *exc_info):
        # Restore the original attribute regardless of how the block ended.
        setattr(self.get_target(), self.attribute, self.save)


class CaptureCall(Capture):
    """Captured invocation of ``callable``.

    Args:
        callable: the capture (or plain object) to be called on execution.
        ctx: shared context dict; a fresh one with empty ``"operations"`` and
            ``"variables"`` lists is created when omitted.
        **kwargs: payload stored on ``self.kwargs``. ``execute`` and
            ``__str__`` read its ``"args"`` (positional arguments) and
            ``"kwargs"`` (keyword arguments) entries.
    """

    def __init__(self, callable, ctx=None, **kwargs):
        if ctx is None:
            self.ctx = {"operations": [], "variables": []}
        else:
            self.ctx = ctx
        self.kwargs = kwargs
        self.callable = callable

    def __str__(self):
        return "{callable}({args},{kwargs})".format(
            callable=self.callable, **self.kwargs
        )

    def execute(self):
        """Resolve the callable and its arguments, perform the call, and return the result.

        Positional arguments and keyword-argument values that are ``Capture``
        instances are executed first; everything else is passed unchanged.
        """

        # TODO: VitalyFedyunin execute nested structures
        def resolve(value):
            return value.execute() if isinstance(value, Capture) else value

        executed_args = [resolve(arg) for arg in self.kwargs["args"]]
        # Keyword values may be captures too; resolve them the same way as
        # positional arguments so the real callable never receives a Capture.
        executed_kwargs = {
            key: resolve(value) for key, value in self.kwargs["kwargs"].items()
        }
        left = get_val(self.callable)
        return left(*executed_args, **executed_kwargs)


class CaptureVariableAssign(CaptureF):
    """Recorded statement ``variable = value``.

    Reads ``kwargs["variable"]`` (the target) and ``kwargs["value"]`` (an
    object with ``execute()``).
    """

    def __str__(self):
        return "{} = {}".format(self.kwargs["variable"], self.kwargs["value"])

    def execute(self):
        """Execute the value and store the result on the variable's ``calculated_value``."""
        result = self.kwargs["value"].execute()
        self.kwargs["variable"].calculated_value = result


class CaptureVariable(Capture):
    """Named slot (``var_<n>``) that is registered in the context's variable list.

    Args:
        value: the captured expression this variable stands for (stored only).
        ctx: shared context dict; the new variable appends itself to
            ``ctx["variables"]``.

    Raises:
        RuntimeError: if capturing has been disabled via ``CaptureControl``.
    """

    # TODO(VitalyFedyunin): This should be atomic and thread safe
    # Class-wide counter used to give each variable a distinct name.
    names_idx = 0

    def __init__(self, value, ctx):
        if CaptureControl.disabled:
            raise RuntimeError("Attempting to create capture variable with capture off")
        self.ctx = ctx
        self.value = value
        self.name = "var_{}".format(CaptureVariable.names_idx)
        CaptureVariable.names_idx += 1
        ctx["variables"].append(self)

    def __str__(self):
        return self.name

    def execute(self):
        """Return the value most recently stored in ``calculated_value``."""
        return self.calculated_value

    def apply_ops(self, dataframe):
        """Replay all recorded operations on ``dataframe`` and return this variable's value."""
        # TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
        first_variable = self.ctx["variables"][0]
        first_variable.calculated_value = dataframe
        for recorded in self.ctx["operations"]:
            recorded.execute()
        return self.calculated_value


class CaptureGetItem(Capture):
    """Captured subscript expression ``left[key]``."""

    def __init__(self, left, key, ctx):
        self.key = key
        self.left = left
        self.ctx = ctx

    def __str__(self):
        return "{}[{}]".format(self.left, get_val(self.key))

    def execute(self):
        """Execute ``left`` and index the result with ``key``."""
        return self.left.execute()[self.key]


class CaptureSetItem(Capture):
    """Captured item assignment ``left[key] = value``.

    Args:
        left: capture whose executed result receives the item.
        key: subscript used for the assignment.
        value: either a ``Capture`` (executed on replay) or a plain value.
        ctx: shared context dict.
    """

    def __init__(self, left, key, value, ctx):
        self.ctx = ctx
        self.left = left
        self.key = key
        self.value = value

    def __str__(self):
        return f"{self.left}[{get_val(self.key)}] = {self.value}"

    def execute(self):
        """Execute ``left`` and assign the resolved value at ``key``."""
        left = self.left.execute()
        # Plain values (numbers, strings, ...) have no execute(); assign them
        # directly instead of failing with AttributeError on replay.
        if isinstance(self.value, Capture):
            value = self.value.execute()
        else:
            value = self.value
        left[self.key] = value


class CaptureAdd(Capture):
    """Captured addition ``left + right``.

    Either operand may be a ``Capture`` (executed on replay) or a plain value.
    """

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} + {self.right}"

    def execute(self):
        """Resolve both operands and return their sum."""

        # get_val() wraps plain strings in literal double quotes (useful for
        # display), which would corrupt the computed result; resolve operands
        # directly so a str operand is used as-is.
        def resolve(operand):
            return operand.execute() if isinstance(operand, Capture) else operand

        return resolve(self.left) + resolve(self.right)


class CaptureMul(Capture):
    """Captured multiplication ``left * right``.

    Either operand may be a ``Capture`` (executed on replay) or a plain value.
    """

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} * {self.right}"

    def execute(self):
        """Resolve both operands and return their product."""

        # get_val() wraps plain strings in literal double quotes (useful for
        # display), which would corrupt the computed result; resolve operands
        # directly so a str operand is used as-is.
        def resolve(operand):
            return operand.execute() if isinstance(operand, Capture) else operand

        return resolve(self.left) * resolve(self.right)


class CaptureSub(Capture):
    """Captured subtraction ``left - right``.

    Either operand may be a ``Capture`` (executed on replay) or a plain value.
    """

    def __init__(self, left, right, ctx):
        self.ctx = ctx
        self.left = left
        self.right = right

    def __str__(self):
        return f"{self.left} - {self.right}"

    def execute(self):
        """Resolve both operands and return their difference."""

        # get_val() wraps plain strings in literal double quotes (useful for
        # display), which would corrupt the computed result; resolve operands
        # directly so a str operand is used as-is.
        def resolve(operand):
            return operand.execute() if isinstance(operand, Capture) else operand

        return resolve(self.left) - resolve(self.right)


class CaptureGetAttr(Capture):
    """Captured attribute access ``src.name``."""

    def __init__(self, src, name, ctx):
        self.name = name
        self.src = src
        self.ctx = ctx

    def __str__(self):
        return "{}.{}".format(self.src, self.name)

    def execute(self):
        """Resolve ``src`` through ``get_val`` and return its ``name`` attribute."""
        return getattr(get_val(self.src), self.name)


def get_val(capture):
    """Resolve ``capture`` to a value.

    * ``Capture`` instances are executed and their result returned.
    * ``str`` values are returned wrapped in double quotes.
    * Anything else is returned unchanged.
    """
    if isinstance(capture, Capture):
        return capture.execute()
    if isinstance(capture, str):
        return '"{}"'.format(capture)
    return capture


class CaptureInitial(CaptureVariable):
    """First variable of a brand-new context; its name gets an ``input_`` prefix.

    Args:
        schema_df: stored under ``"schema_df"`` in the freshly created context.
    """

    def __init__(self, schema_df=None):
        new_ctx: dict[str, Any] = dict(
            operations=[],
            variables=[],
            schema_df=schema_df,
        )
        super().__init__(None, new_ctx)
        # The parent assigned "var_<n>"; mark this one as the input.
        self.name = "input_{}".format(self.name)


class CaptureDataFrame(CaptureInitial):
    """``CaptureInitial`` under a DataFrame-specific name; adds no behaviour of its own."""


class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
    """Captured DataFrame that additionally exposes datapipe-style operations.

    Names listed in ``DATAPIPES_OPS`` are forwarded to the datapipe built by
    ``as_datapipe()``; other unknown attributes are recorded as captures by the
    parent class.
    """

    def as_datapipe(self):
        """Return a ``DataFrameTracedOps`` that applies this capture to the first variable's source datapipe."""
        root_variable = self.ctx["variables"][0]
        return DataFrameTracedOps(root_variable.source_datapipe, self)

    def raw_iterator(self):
        """Return the iterator of the traced datapipe."""
        return self.as_datapipe().__iter__()

    def __iter__(self):
        return iter(self._dataframes_as_tuples())

    def batch(self, batch_size=10, drop_last: bool = False, wrapper_class=DataChunkDF):
        """Concatenate per-row frames ``batch_size`` at a time, then batch them one per batch.

        The resulting datapipe is flagged with ``_dp_contains_dataframe = True``.
        """
        concatenated = self._dataframes_per_row()._dataframes_concat(batch_size)
        batched = concatenated.as_datapipe().batch(
            1, drop_last=drop_last, wrapper_class=wrapper_class
        )
        batched._dp_contains_dataframe = True
        return batched

    def groupby(
        self,
        group_key_fn,
        *,
        buffer_size=10000,
        group_size=None,
        guaranteed_group_size=None,
        drop_remaining=False,
    ):
        """Split into per-row frames and delegate to the datapipe ``groupby`` with the given options."""
        per_row = self._dataframes_per_row()
        return per_row.as_datapipe().groupby(
            group_key_fn,
            buffer_size=buffer_size,
            group_size=group_size,
            guaranteed_group_size=guaranteed_group_size,
            drop_remaining=drop_remaining,
        )

    def shuffle(self, *args, **kwargs):
        """Delegate to ``_dataframes_shuffle``."""
        return self._dataframes_shuffle(*args, **kwargs)

    def filter(self, *args, **kwargs):
        """Delegate to ``_dataframes_filter``."""
        return self._dataframes_filter(*args, **kwargs)

    def collate(self, *args, **kwargs):
        """Always raises: collating is not supported on this stream."""
        raise RuntimeError("Can't collate unbatched DataFrames stream")

    def __getattr__(self, attrname):  # ?
        # Explicitly unsupported names must look missing rather than be captured.
        if attrname in UNIMPLEMENTED_ATTR:
            raise AttributeError("Attempting to get ", attrname)
        # Datapipe operations are resolved on the traced datapipe.
        if attrname in DATAPIPES_OPS:
            traced = self.as_datapipe()
            return traced.__getattr__(attrname)
        # Everything else is recorded as a captured attribute access.
        return super().__getattr__(attrname)


@functional_datapipe("trace_as_dataframe")
class DataFrameTracer(CaptureDataFrameWithDataPipeOps, IterDataPipe):  # type: ignore[misc]
    """Entry point of tracing, registered under the functional name ``trace_as_dataframe``.

    Args:
        source_datapipe: datapipe whose items the captured operations run on.
        schema_df: object kept as the context's ``"schema_df"``; when omitted,
            the first item of ``source_datapipe`` is used.
    """

    source_datapipe: Optional[Any] = None

    # TODO(VitalyFedyunin): Must implement all special functions of datapipes

    def __init__(self, source_datapipe, schema_df=None):
        self.source_datapipe = source_datapipe
        if schema_df is None:
            # No schema given: take the first item of the source instead.
            schema_df = next(iter(self.source_datapipe))
        super().__init__(schema_df=schema_df)

    def is_shardable(self):
        """Report that this datapipe is not shardable."""
        return False

    def set_shuffle_settings(self, *args, **kwargs):
        """Accept and ignore any shuffle settings."""
