# mypy: allow-untyped-defs
# mypy: disable-error-code=arg-type
import importlib
import inspect

from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


def register_quantized_ops(domain: str, version: int):
    """Register the functions of ``torch.onnx.symbolic_caffe2`` as symbolics.

    Every member of that module for which ``inspect.isfunction`` is true is
    registered as ``"{domain}::{name}"`` for ``version``, unless that qualified
    name is already registered. Helpers such as ``nchw2nhwc`` and this function
    itself are included. Functions whose name is in the override set below are
    additionally registered as ``"aten::{name}"`` with ``custom=True``, so they
    replace the builtin ATen symbolic for that op.

    Args:
        domain: Namespace prefix for the registered op names.
        version: Opset version to register the functions under.
    """
    caffe2_module = importlib.import_module("torch.onnx.symbolic_caffe2")
    members = inspect.getmembers(caffe2_module)
    # ATen ops whose builtin symbolic is overridden by the same-named function
    # in this module.
    overridden_aten_ops = frozenset(
        (
            "relu",
            "_empty_affine_quantized",
            "dequantize",
            "quantize_per_tensor",
            "upsample_nearest2d",
            "avg_pool2d",
            "reshape",
            "slice",
            "cat",
            "max_pool2d",
            "sigmoid",
        )
    )
    for op_name, candidate in members:
        if not inspect.isfunction(candidate):
            continue
        qualified_name = f"{domain}::{op_name}"
        if registration.registry.is_registered_op(qualified_name, version):
            continue
        if op_name in overridden_aten_ops:
            # Route the builtin aten op to the quantization-aware symbolic.
            registration.registry.register(
                f"aten::{op_name}", version, candidate, custom=True
            )
        registration.registry.register(qualified_name, version, candidate)


def _permute_helper(g: jit_utils.GraphContext, input, axes):
    """Emit a Caffe2 ``Int8Transpose`` of ``input`` with permutation ``axes``.

    The output copies the ``Y_scale``/``Y_zero_point`` attributes of the node
    that produced ``input`` and is recorded in ``symbolic_helper._quantized_ops``.

    Args:
        g: Graph context to emit the node into.
        input: Value to transpose; its producing node must carry ``Y_scale``
            and ``Y_zero_point`` attributes.
        axes: Permutation of the input dimensions.

    Returns:
        The transposed value.
    """
    producer = input.node()
    transposed = g.op(
        "_caffe2::Int8Transpose",
        input,
        axes_i=axes,
        Y_scale_f=symbolic_helper._node_get(producer, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(producer, "Y_zero_point"),
    )
    symbolic_helper._quantized_ops.add(transposed)
    return transposed


def nchw2nhwc(g: jit_utils.GraphContext, input):
    """Permute a quantized 4-D value with axes ``[0, 2, 3, 1]`` (NCHW -> NHWC)."""
    return _permute_helper(g, input, [0, 2, 3, 1])


def nhwc2nchw(g: jit_utils.GraphContext, input):
    """Permute a quantized 4-D value with axes ``[0, 3, 1, 2]`` (NHWC -> NCHW)."""
    return _permute_helper(g, input, [0, 3, 1, 2])


def linear_prepack(g: jit_utils.GraphContext, weight, bias):
    """Map a linear weight-prepack op to a placeholder Caffe2 ``WeightPrepack`` node.

    Returns:
        The placeholder node's output, recorded as quantized.
    """
    # The placeholder exists so that the ONNX -> Caffe2 conversion can look
    # up the original weight and bias from this node.
    prepacked = g.op("_caffe2::WeightPrepack", weight, bias)
    symbolic_helper._quantized_ops.add(prepacked)
    return prepacked


@symbolic_helper.parse_args("v", "v", "v", "f", "i")
def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point):
    """Export a quantized linear op as Caffe2 ``Int8FC``.

    Args:
        g: Graph context to emit the node into.
        input: Input value.
        weight: Weight value.
        bias: Bias value.
        scale: Output quantization scale (parsed as float).
        zero_point: Output quantization zero point (parsed as int).

    Returns:
        The ``Int8FC`` output, recorded as quantized.
    """
    fc = g.op(
        "_caffe2::Int8FC",
        input,
        weight,
        bias,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(fc)
    return fc


def conv_prepack(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    """Map a conv weight-prepack op to a placeholder Caffe2 ``WeightPrepack`` node.

    ``input``, ``weight`` and ``bias`` become the node's inputs. ``stride``,
    ``padding``, ``dilation`` and ``groups`` are not used here.

    Returns:
        The placeholder node's output, recorded as quantized.
    """
    # The placeholder exists so that the ONNX -> Caffe2 conversion can look
    # up the original weight and bias from this node.
    prepacked = g.op("_caffe2::WeightPrepack", input, weight, bias)
    symbolic_helper._quantized_ops.add(prepacked)
    return prepacked


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    """Export a quantized 2-D convolution as Caffe2 ``Int8Conv`` in NHWC order.

    Args:
        g: Graph context to emit the node into.
        input: Input value.
        weight: Weight value. Its node must carry a ``shape`` attribute;
            elements ``[1:3]`` of it are used as the kernel size
            (NOTE(review): implies spatial dims at positions 1-2 of the weight
            shape — confirm against the prepack layout).
        bias: Bias value.
        stride: Per-spatial-dim strides.
        padding: Per-spatial-dim padding, duplicated to form begin/end pads.
        dilation: Per-spatial-dim dilations.
        groups: Number of groups.
        scale: Output quantization scale.
        zero_point: Output quantization zero point.

    Returns:
        The ``Int8Conv`` output, recorded as quantized.
    """
    weight_shape = weight.node()["shape"]
    conv = g.op(
        "_caffe2::Int8Conv",
        input,
        weight,
        bias,
        strides_i=stride,
        pads_i=padding * 2,
        dilations_i=dilation,
        group_i=groups,
        kernels_i=weight_shape[1:3],
        order_s="NHWC",
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(conv)
    return conv


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d_relu(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    """Export a fused quantized conv2d + ReLU as Caffe2 ``Int8ConvRelu`` (NHWC).

    Takes the same arguments as ``conv2d``. The kernel size comes from elements
    ``[1:3]`` of the weight node's ``shape`` attribute
    (NOTE(review): implies spatial dims at positions 1-2 — confirm).

    Returns:
        The ``Int8ConvRelu`` output, recorded as quantized.
    """
    weight_shape = weight.node()["shape"]
    fused = g.op(
        "_caffe2::Int8ConvRelu",
        input,
        weight,
        bias,
        strides_i=stride,
        pads_i=padding * 2,
        dilations_i=dilation,
        group_i=groups,
        kernels_i=weight_shape[1:3],
        order_s="NHWC",
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(fused)
    return fused


@symbolic_helper.parse_args("v", "v", "f", "i")
def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point):
    """Export a quantized elementwise add as Caffe2 ``Int8Add``.

    Args:
        g: Graph context to emit the node into.
        input_a: First operand.
        input_b: Second operand.
        scale: Output quantization scale.
        zero_point: Output quantization zero point.

    Returns:
        The ``Int8Add`` output, recorded as quantized.
    """
    summed = g.op(
        "_caffe2::Int8Add",
        input_a,
        input_b,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(summed)
    return summed


@symbolic_helper.parse_args("v")
def relu(g: jit_utils.GraphContext, input):
    """Export ReLU, using Caffe2 ``Int8Relu`` when ``input`` is quantized.

    Values not recorded in ``symbolic_helper._quantized_ops`` go to the opset 9
    symbolic. Quantized values keep the ``Y_scale``/``Y_zero_point`` of the
    node that produced them.
    """
    if input in symbolic_helper._quantized_ops:
        producer = input.node()
        activated = g.op(
            "_caffe2::Int8Relu",
            input,
            Y_scale_f=symbolic_helper._node_get(producer, "Y_scale"),
            Y_zero_point_i=symbolic_helper._node_get(producer, "Y_zero_point"),
        )
        symbolic_helper._quantized_ops.add(activated)
        return activated
    return opset9.relu(g, input)


@symbolic_helper.parse_args("v", "f", "i", "t")
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    """Export ``quantize_per_tensor`` as Caffe2 ``Int8Quantize``.

    Args:
        g: Graph context to emit the node into.
        input: Value to quantize.
        scale: Quantization scale.
        zero_point: Quantization zero point.
        dtype: Parsed but not used; it is not forwarded to the Caffe2 op.

    Returns:
        The ``Int8Quantize`` output, recorded as quantized.
    """
    quantized = g.op(
        "_caffe2::Int8Quantize",
        input,
        Y_scale_f=scale,
        Y_zero_point_i=zero_point,
    )
    symbolic_helper._quantized_ops.add(quantized)
    return quantized


@symbolic_helper.parse_args("v")
def dequantize(g: jit_utils.GraphContext, input):
    """Export ``dequantize`` as Caffe2 ``Int8Dequantize``.

    The result is not recorded in ``symbolic_helper._quantized_ops``. The other
    symbolics in this module therefore take their non-quantized path for it.
    """
    dequantized = g.op("_caffe2::Int8Dequantize", input)
    return dequantized


@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t")
def _empty_affine_quantized(
    g: jit_utils.GraphContext,
    input,
    shape,
    scale,
    zero_point,
    dtype,
    pin_memory,
    memory_format,
    layout,
):
    """Export ``aten::_empty_affine_quantized`` as a pass-through.

    No node is emitted. The first argument is returned unchanged and all
    remaining arguments are ignored.
    """
    # Nothing to build in the graph; hand the first value straight back.
    return input


def upsample_nearest2d(
    g: jit_utils.GraphContext,
    input,
    output_size,
    align_corners=None,
    scales_h=None,
    scales_w=None,
):
    """Export nearest 2-D upsampling, using Caffe2 ``Int8ResizeNearest`` when quantized.

    Non-quantized inputs go to the opset 9 symbolic with only ``input``,
    ``output_size`` and ``align_corners``. ``scales_h`` and ``scales_w`` are not
    forwarded (NOTE(review): confirm this matches the aten overload being
    exported).

    Quantized inputs are handled in three steps:

    1. ``output_size`` is parsed as a constant int list.
    2. The value is permuted NCHW -> NHWC and resized with
       ``Int8ResizeNearest``, keeping the input's ``Y_scale``/``Y_zero_point``.
    3. The result is permuted back to NCHW.
    """
    if input in symbolic_helper._quantized_ops:
        size = symbolic_helper._parse_arg(output_size, "is")
        producer = input.node()
        # Quantization attributes come from the original (NCHW) input's node,
        # read before any transpose is inserted.
        attrs = {
            "output_size_i": size,
            "Y_scale_f": symbolic_helper._node_get(producer, "Y_scale"),
            "Y_zero_point_i": symbolic_helper._node_get(producer, "Y_zero_point"),
        }
        nhwc_input = nchw2nhwc(g, input)
        nhwc_output = g.op("_caffe2::Int8ResizeNearest", nhwc_input, **attrs)
        resized = nhwc2nchw(g, nhwc_output)
        symbolic_helper._quantized_ops.add(resized)
        return resized
    return opset9.upsample_nearest2d(g, input, output_size, align_corners)  # type: ignore[attr-defined]


@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def max_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    """Export ``max_pool2d``, using Caffe2 ``Int8MaxPool`` for quantized inputs.

    Values not recorded in ``symbolic_helper._quantized_ops`` go to the opset 9
    symbolic unchanged. Quantized values are permuted NCHW -> NHWC, pooled with
    ``Int8MaxPool`` (keeping the input's ``Y_scale``/``Y_zero_point``) and
    permuted back.

    Args:
        g: Graph context to emit nodes into.
        input: Value to pool.
        kernel_size: Pooling window. On the quantized path only
            ``kernel_size[0]`` is forwarded, so a square window is assumed.
        stride: Window strides. An empty list means the PyTorch default of
            ``stride == kernel_size``.
        padding: Per-spatial-dim padding, duplicated to form begin/end pads.
        dilation: Must be all ones on the quantized path.
        ceil_mode: Not forwarded on the quantized path.

    Returns:
        The pooled value, recorded as quantized on the quantized path.

    Raises:
        RuntimeError: If a quantized input is pooled with dilation other than 1.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.max_pool2d(  # type: ignore[attr-defined]
            g, input, kernel_size, stride, padding, dilation, ceil_mode
        )
    # Int8MaxPool is emitted without any dilation attribute, so a non-unit
    # dilation would silently export a different computation.
    if any(d != 1 for d in dilation):
        raise RuntimeError("ONNX quantized max_pool2d export only works for dilation 1.")
    # An omitted stride is recorded as an empty list and means "same as
    # kernel_size" (the opset 9 symbolic applies the same rule); without this
    # an empty strides attribute would be emitted.
    if not stride:
        stride = kernel_size
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def avg_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    """Export ``avg_pool2d``, using Caffe2 ``Int8AveragePool`` for quantized inputs.

    Values not recorded in ``symbolic_helper._quantized_ops`` go to the opset 9
    symbolic unchanged. Quantized values are permuted NCHW -> NHWC, pooled with
    ``Int8AveragePool`` (keeping the input's ``Y_scale``/``Y_zero_point``) and
    permuted back.

    Args:
        g: Graph context to emit nodes into.
        input: Value to pool.
        kernel_size: Pooling window. On the quantized path only
            ``kernel_size[0]`` is forwarded, so a square window is assumed.
        stride: Window strides. An empty list means the PyTorch default of
            ``stride == kernel_size``.
        padding: Per-spatial-dim padding, duplicated to form begin/end pads.
        ceil_mode: Not forwarded on the quantized path.
        count_include_pad: Not forwarded on the quantized path.
        divisor_override: Not forwarded on the quantized path.

    Returns:
        The pooled value, recorded as quantized on the quantized path.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.avg_pool2d(  # type: ignore[attr-defined]
            g,
            input,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
    # An omitted stride is recorded as an empty list and means "same as
    # kernel_size" (the opset 9 symbolic applies the same rule); without this
    # an empty strides attribute would be emitted.
    if not stride:
        stride = kernel_size
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


def reshape(g: jit_utils.GraphContext, input, shape):
    """Export reshape, using Caffe2 ``Int8Reshape`` when ``input`` is quantized.

    Values not recorded in ``symbolic_helper._quantized_ops`` go to the opset 9
    symbolic. On the quantized path ``shape`` is passed as a graph input (not an
    attribute), and the output keeps the input's ``Y_scale``/``Y_zero_point``.
    """
    if input in symbolic_helper._quantized_ops:
        producer = input.node()
        reshaped = g.op(
            "_caffe2::Int8Reshape",
            input,
            shape,
            Y_scale_f=symbolic_helper._node_get(producer, "Y_scale"),
            Y_zero_point_i=symbolic_helper._node_get(producer, "Y_zero_point"),
        )
        symbolic_helper._quantized_ops.add(reshaped)
        return reshaped
    return opset9.reshape(g, input, shape)


@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def slice(g: jit_utils.GraphContext, input, dim, start, end, step):
    """Export slice, using Caffe2 ``Int8Slice`` when ``input`` is quantized.

    Values not recorded in ``symbolic_helper._quantized_ops`` go to the opset 9
    symbolic. On the quantized path ``start``, ``end`` and ``dim`` must be
    parseable as constant ints, and the output keeps the input's
    ``Y_scale``/``Y_zero_point``.

    Raises:
        RuntimeError: If a quantized input is sliced with a step other than 1.
    """
    if input not in symbolic_helper._quantized_ops:
        return opset9.slice(g, input, dim, start, end, step)
    if step != 1:
        raise RuntimeError("ONNX quantized slice export only works for step 1.")

    producer = input.node()
    # Keyword arguments are evaluated left to right: start, end, dim, then
    # the quantization attributes.
    sliced = g.op(
        "_caffe2::Int8Slice",
        input,
        start_idx_i=symbolic_helper._parse_arg(start, "i"),
        end_idx_i=symbolic_helper._parse_arg(end, "i"),
        dim_i=symbolic_helper._parse_arg(dim, "i"),
        Y_scale_f=symbolic_helper._node_get(producer, "Y_scale"),
        Y_zero_point_i=symbolic_helper._node_get(producer, "Y_zero_point"),
    )
    symbolic_helper._quantized_ops.add(sliced)
    return sliced


def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None):
    """Export concatenation, using Caffe2 ``Int8Concat`` when the inputs are quantized.

    Only the first tensor decides the path: if it is not recorded in
    ``symbolic_helper._quantized_ops``, the opset 9 symbolic is used.

    On the quantized path ``dim`` must be a constant int. The output takes its
    ``Y_scale``/``Y_zero_point`` from the first tensor's producing node.
    ``scale`` and ``zero_point`` are accepted but not used.

    Args:
        g: Graph context to emit the node into.
        tensor_list: List value holding the tensors to concatenate.
        dim: Concatenation axis.
        scale: Unused.
        zero_point: Unused.

    Returns:
        The concatenated value, recorded as quantized on the quantized path.
    """
    tensors = symbolic_helper._unpack_list(tensor_list)
    first = tensors[0]
    if first not in symbolic_helper._quantized_ops:
        return opset9.cat(g, tensor_list, dim)

    dim = symbolic_helper._parse_arg(dim, "i")
    first_node = first.node()
    # Read attributes through _node_get, as every other symbolic in this
    # module does, rather than subscripting the node directly.
    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(first_node, "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(first_node, "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def sigmoid(g: jit_utils.GraphContext, input):
    """Export sigmoid, using Caffe2 ``Int8Sigmoid`` when ``input`` is quantized.

    Values not recorded in ``symbolic_helper._quantized_ops`` go to the opset 9
    symbolic. On the quantized path the output quantization parameters are the
    fixed constants below; the input's attributes are not read.
    """
    if input in symbolic_helper._quantized_ops:
        # Caffe2 expects the output scale to be 1/2^8
        # and output zero_point to be 0 (quint8 type)
        activated = g.op(
            "_caffe2::Int8Sigmoid",
            input,
            Y_scale_f=1.0 / 256,
            Y_zero_point_i=0,
        )
        symbolic_helper._quantized_ops.add(activated)
        return activated
    return opset9.sigmoid(g, input)
