# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import logging
from distutils.version import LooseVersion

import torch
import torch as th
import torch.nn as nn

if LooseVersion(torch.__version__) < LooseVersion("1.8.0"):
    logging.warning(
        f"torch.fx requires version higher than 1.8.0. But You are using an old version PyTorch {torch.__version__}. "
    )


def count_clamp(input_shapes, output_shapes):
    """Ensures tensor array sizes are appropriate by clamping specified input and output shapes."""
    return 0


def count_mul(input_shapes, output_shapes):
    """Returns the number of elements in the first output shape."""
    return output_shapes[0].numel()


def count_matmul(input_shapes, output_shapes):
    """Calculates matrix multiplication ops based on input and output tensor shapes for performance profiling."""
    in_shape = input_shapes[0]
    out_shape = output_shapes[0]
    in_features = in_shape[-1]
    num_elements = out_shape.numel()
    return in_features * num_elements


def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
    """Calculates the total FLOPs for a linear layer, including bias operations if specified."""
    flops = count_matmul(input_shapes, output_shapes)
    if "bias" in kwargs:
        flops += output_shapes[0].numel()
    return flops


from .vision.calc_func import calculate_conv


def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
    """Calculates total operations (FLOPs) for a 2D conv layer based on input and output shapes using
    `calculate_conv`.
    """
    inputs, weight, bias, stride, padding, dilation, groups = args
    if len(input_shapes) == 2:
        x_shape, k_shape = input_shapes
    elif len(input_shapes) == 3:
        x_shape, k_shape, b_shape = input_shapes
    out_shape = output_shapes[0]

    kernel_parameters = k_shape[2:].numel()
    bias_op = 0  # check it later
    in_channel = x_shape[1]

    total_ops = calculate_conv(bias_op, kernel_parameters, out_shape.numel(), in_channel, groups).item()
    return int(total_ops)


def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
    """Counts the FLOPs for a fully connected (linear) layer in a neural network module."""
    return count_matmul(input_shapes, output_shapes)


def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
    """Returns 0 for a neural network module, input shapes, and output shapes in PyTorch."""
    return 0


def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
    """Calculates FLOPs for a 2D Conv2D layer in an nn.Module using input and output shapes."""
    bias_op = 1 if module.bias is not None else 0
    out_shape = output_shapes[0]

    in_channel = module.in_channels
    groups = module.groups
    kernel_ops = module.weight.shape[2:].numel()
    total_ops = calculate_conv(bias_op, kernel_ops, out_shape.numel(), in_channel, groups).item()
    return int(total_ops)


def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
    """Calculate FLOPs for an nn.BatchNorm2d layer based on the given output shape."""
    assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
    y = output_shapes[0]
    return 2 * y.numel()


zero_ops = (
    nn.ReLU,
    nn.ReLU6,
    nn.Dropout,
    nn.MaxPool2d,
    nn.AvgPool2d,
    nn.AdaptiveAvgPool2d,
)

count_map = {
    nn.Linear: count_nn_linear,
    nn.Conv2d: count_nn_conv2d,
    nn.BatchNorm2d: count_nn_bn2d,
    "function linear": count_fn_linear,
    "clamp": count_clamp,
    "built-in function add": count_zero_ops,
    "built-in method fl": count_zero_ops,
    "built-in method conv2d of type object": count_fn_conv2d,
    "built-in function mul": count_mul,
    "built-in function truediv": count_mul,
}

for k in zero_ops:
    count_map[k] = count_zero_ops

missing_maps = {}

from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

from .utils import prRed, prYellow


def null_print(*args, **kwargs):
    """A no-op print function that takes any arguments without performing any actions."""
    return


def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
    """Profiles nn.Module for total FLOPs per operation and prints detailed nodes if verbose."""
    gm: torch.fx.GraphModule = symbolic_trace(mod)
    ShapeProp(gm).propagate(input)

    fprint = null_print
    if verbose:
        fprint = print

    v_maps = {}
    total_flops = 0

    for node in gm.graph.nodes:
        # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}")
        fprint(f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}")
        # node_op_type = str(node.target).split(".")[-1]
        node_flops = None

        input_shapes = []
        fprint("input_shape:", end="\t")
        for arg in node.args:
            if str(arg) not in v_maps:
                continue
            fprint(f"{v_maps[str(arg)]}", end="\t")
            input_shapes.append(v_maps[str(arg)])
        fprint()
        fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}")
        output_shapes = [node.meta["tensor_meta"].shape]
        if node.op in ["output", "placeholder"]:
            node_flops = 0
        elif node.op == "call_function":
            # torch internal functions
            key = str(node.target).split("at")[0].replace("<", "").replace(">", "").strip()
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes, *node.args, **node.kwargs)
            else:
                missing_maps[key] = (node.op, key)
                prRed(f"|{key}| is missing")
        elif node.op == "call_method":
            # torch internal functions
            # fprint(str(node.target) in count_map, str(node.target), count_map.keys())
            key = str(node.target)
            if key in count_map:
                node_flops = count_map[key](input_shapes, output_shapes)
            else:
                missing_maps[key] = (node.op, key)
                prRed(f"{key} is missing")
        elif node.op == "call_module":
            # torch.nn modules
            # m = getattr(mod, node.target, None)
            m = mod.get_submodule(node.target)
            key = type(m)
            fprint(type(m), type(m) in count_map)
            if type(m) in count_map:
                node_flops = count_map[type(m)](m, input_shapes, output_shapes)
            else:
                missing_maps[key] = (node.op,)
                prRed(f"{key} is missing")
            print("module type:", type(m))
            if isinstance(m, zero_ops):
                print("weight_shape: None")
            else:
                print(type(m))
                print(f"weight_shape: {mod.state_dict()[f'{node.target}.weight'].shape}")

        v_maps[str(node.name)] = node.meta["tensor_meta"].shape
        if node_flops is not None:
            total_flops += node_flops
        prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}")
        fprint("==" * 40)

    if len(missing_maps.keys()) > 0:
        from pprint import pprint

        print("Missing operators: ")
        pprint(missing_maps)
    return total_flops


if __name__ == "__main__":

    class MyOP(nn.Module):
        def forward(self, input):
            """Performs forward pass on given input data."""
            return input / 1

    class MyModule(torch.nn.Module):
        def __init__(self):
            """Initializes MyModule with two linear layers and a custom MyOP operator."""
            super().__init__()
            self.linear1 = torch.nn.Linear(5, 3)
            self.linear2 = torch.nn.Linear(5, 3)
            self.myop = MyOP()

        def forward(self, x):
            """Applies two linear transformations to the input tensor, clamps the second, then combines and processes
            with MyOP operator.
            """
            out1 = self.linear1(x)
            out2 = self.linear2(x).clamp(min=0.0, max=1.0)
            return self.myop(out1 + out2)

    net = MyModule()
    data = th.randn(20, 5)
    flops = fx_profile(net, data, verbose=False)
    print(flops)
