# Copyright (c) 2024 Huawei Technologies Co., Ltd
# Copyright (c) 2019, Facebook CORPORATION.
# All rights reserved.
#
# Licensed under the BSD 3-Clause License  (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict, Sequence
import copy
from dataclasses import dataclass

from torchnpugen.model import BaseTy, SchemaKind, BaseType, Argument, NativeFunction, ListType
from torchnpugen.context import native_function_manager
from torchnpugen.api.types import NativeSignature
from torchnpugen.api import cpp


# gen_opapi keys permitted for structured functions (per the constant's name;
# the code that enforces it is not part of this file chunk).
STRUCTURED_GEN_OPAPI_ALLOWED_KEYS = {'exec', 'new_params'}


def filt_input_tensor(arguments: Sequence[Argument]) -> List[str]:
    """Return C++ expressions that reference the tensor inputs among ``arguments``.

    A plain ``Tensor`` argument contributes its own name. A list whose element
    type is a plain ``Tensor`` contributes its first element, ``'<name>[0]'``.
    All other argument types are skipped. This includes lists whose element is
    not a ``BaseType`` (e.g. optional tensors), which may have no ``name``
    attribute to compare.

    Args:
        arguments: flattened function arguments.

    Returns:
        The expressions, in argument order.
    """
    input_tensors = []
    for arg in arguments:
        arg_type = arg.type
        if isinstance(arg_type, BaseType) and arg_type.name == BaseTy.Tensor:
            input_tensors.append(arg.name)
        elif (
            isinstance(arg_type, ListType)
            # Guard the element type before reading ``.name``: a non-BaseType
            # element would otherwise raise AttributeError.
            and isinstance(arg_type.elem, BaseType)
            and arg_type.elem.name == BaseTy.Tensor
        ):
            input_tensors.append(f'{arg.name}[0]')
    return input_tensors


@dataclass(frozen=True)
class ResInfo:
    """Description of one result tensor of a generated op_api function."""

    # Result name (an output argument name for out= functions, a config key otherwise).
    name: str
    # Size expression: '<input>.sizes()' when the configured size names an input
    # tensor, otherwise the configured text verbatim.
    size: str
    # Dtype expression: '<input>.scalar_type()' when the configured dtype names an
    # input tensor, otherwise the configured text verbatim.
    dtype: str
    # First tensor input of the function (see filt_input_tensor).
    option: str
    # Optional 'name' entry of the result configuration.
    infer_name: str = None

    @staticmethod
    def parse(results: Dict[str, Dict[str, str]], f: 'NativeFunction') -> List['ResInfo']:
        """Build one ResInfo per result of ``f`` from its gen_opapi result configuration.

        Which results are described depends on the schema kind of ``f``:

        * out: every output name of ``f``. An output without configuration, or
          without ``size``/``dtype``, uses its own name for the missing value.
        * inplace: every key of ``results``; ``size``/``dtype`` may stay None.
        * otherwise: every key of ``results``; ``size`` and ``dtype`` are required.

        Args:
            results: mapping from result name to a dict with optional ``size``,
                ``dtype`` and ``name`` keys. It is only read, never modified.
            f: the function being generated.

        Returns:
            The ResInfo objects in result order.

        Raises:
            RuntimeError: a functional ``f`` whose result count differs from its
                number of ``Tensor`` returns; a missing result entry; a missing
                ``size``/``dtype`` where one is required; or ``f`` has no tensor
                input to use as ``option``.
        """
        kind = f.func.kind()
        # Count only plain ``Tensor`` returns. Other return types (e.g. tensor
        # lists) may not expose ``name``, so check the type before comparing.
        tensor_number = sum(
            1 for ret in f.func.returns
            if isinstance(ret.type, BaseType) and ret.type.name == BaseTy.Tensor
        )
        if len(results) != tensor_number and kind == SchemaKind.functional:
            raise RuntimeError(
                f"The number of result info in yaml is {len(results)}."
                f"That does not match {f.func.name}'s returns tensor number {tensor_number}"
            )

        input_tensors = filt_input_tensor(f.func.arguments.flat_all)
        res_infos = []

        if kind == SchemaKind.out:
            result_names = cpp.return_names(f)
        else:
            result_names = list(results)

        for name in result_names:
            info = results.get(name, None)
            if kind == SchemaKind.out and info is None:
                # Unconfigured out= output: size and dtype come from the output itself.
                size, dtype = name, name
                infer_name = None
            elif info is None:
                raise RuntimeError(f"The '{name}' is missing in results for {f.func.name}")
            elif kind == SchemaKind.inplace:
                # ``.get`` rather than ``.pop``: the config may be shared (e.g. via
                # structured_inherit) and must not be consumed here.
                size = info.get('size', None)
                dtype = info.get('dtype', None)
                infer_name = info.get('name', None)
            else:
                size = info.get('size', None)
                dtype = info.get('dtype', None)
                if kind == SchemaKind.out:
                    size = name if size is None else size
                    dtype = name if dtype is None else dtype
                if size is None:
                    raise RuntimeError(f"The {name}'s size  is None in {f.func.name}")
                if dtype is None:
                    raise RuntimeError(f"The {name}'s dtype  is None in {f.func.name}")
                infer_name = info.get('name', None)

            size_formula = f'{size}.sizes()' if size in input_tensors else size
            dtype_formula = f'{dtype}.scalar_type()' if dtype in input_tensors else dtype

            if not input_tensors:
                # Previously an opaque IndexError from ``input_tensors[0]``.
                raise RuntimeError(
                    f"{f.func.name} has no tensor input to use as the option of result '{name}'"
                )

            res_infos.append(
                ResInfo(
                    name=name, size=size_formula, dtype=dtype_formula, option=input_tensors[0], infer_name=infer_name
                )
            )

        return res_infos


@dataclass(frozen=True)
class CpuScalarOp:
    """One ``cpu_scalar_op`` entry of a gen_opapi configuration."""

    # Value of the entry's 'param' key ('' when absent).
    param: str
    # Value of the entry's 'exec' key ('' when absent).
    exec_cmd: str

    @staticmethod
    def parse(op_list: List[Dict[str, str]]) -> List['CpuScalarOp']:
        """Convert each mapping in ``op_list`` into a CpuScalarOp, preserving order.

        Missing ``param``/``exec`` keys become empty strings.
        """
        return [
            CpuScalarOp(param=entry.get('param', ''), exec_cmd=entry.get('exec', ''))
            for entry in op_list
        ]


@dataclass(frozen=True)
class StructInfo:
    """Code-generation description of one function configured with ``gen_opapi``.

    Built by :meth:`from_yaml`. An in-place function that uses
    ``structured_inherit`` only fills ``name``, ``func``, ``structured_inherit``
    and ``acl_op``. Every other entry also carries the aclnn call and result
    information.
    """

    # Function name: the schema text before the first '('.
    name: str
    # Function an in-place entry inherits from (only set on that path).
    structured_inherit: 'NativeFunction' = None
    # First comma-separated item of the ``exec`` configuration.
    aclnn_name: str = None
    # NativeFunction whose schema matches this entry.
    func: 'NativeFunction' = None
    # Comma-joined aclnn call arguments (first item is aclnn_name).
    cmd_args: str = None
    # Return expression text with a leading space, or '' when nothing is returned.
    return_args: str = None
    # True when the function's impl_ns also contains 'acl_op'.
    acl_op: bool = None
    # Result descriptions; empty when the entry is marked ``structured``.
    results: List[ResInfo] = None
    # ``new_params`` mapping of the gen_opapi configuration.
    new_params: Dict[str, str] = None
    # ``integral_identity_tensor`` value of the gen_opapi configuration.
    integral_identity_tensor: str = None
    # ``cpu_scalar_h2d`` value of the gen_opapi configuration.
    cpu_scalar_h2d: List[str] = None
    # Parsed ``cpu_scalar_op`` entries.
    cpu_scalar_op: List[CpuScalarOp] = None
    # Truthiness of the entry's ``structured`` flag.
    use_structured_meta: bool = False

    @staticmethod
    def from_yaml(
        es: Sequence[Dict[str, object]],
        native_functions: Sequence[NativeFunction],
    ) -> "List[StructInfo]":
        """
        Parse StructInfo objects from entries directly parsed from op_plugin_functions.yaml.

        Args:
            es: yaml entries, each with a ``func`` schema string and a
                ``gen_opapi`` mapping. Entries are modified in place: ``acl_op``,
                ``op_api`` and ``func`` are popped, as are the consumed
                ``gen_opapi`` keys of non-inherited entries.
            native_functions: parsed functions; every schema must be unique.

        Returns:
            One StructInfo per entry of ``es``, in order.

        Raises:
            RuntimeError: duplicate or unknown schema, missing op_api
                implementation, invalid/missing ``gen_opapi`` or ``exec``
                configuration, unknown ``structured_inherit`` target, or
                inconsistent result information (from ResInfo.parse).
            ValueError: an out= function whose result configuration names a
                key that is not one of its outputs.
        """
        functions_by_schema: Dict[str, NativeFunction] = {}
        for function in native_functions:
            schema_key = str(function.func)
            if schema_key in functions_by_schema:
                raise RuntimeError(f"{function.func} has multiple definitions in op_plugin_functions.yaml")
            functions_by_schema[schema_key] = function

        def gen_func_name(schema: str) -> str:
            # Name part of a schema string: everything before the first '('.
            return schema.split('(')[0]

        def get_return_arguments(func: NativeFunction, results: List[ResInfo]) -> List[str]:
            # In-place ops return ``self``. Out ops without result info return their
            # outputs. Everything else returns the described results.
            kind = func.func.kind()
            if kind == SchemaKind.inplace:
                return [func.func.arguments.self_arg.argument.name]
            if kind == SchemaKind.out and len(results) == 0:
                return cpp.return_names(func)
            return [result.name for result in results]

        def format_return_args(func: NativeFunction, return_argument: List[str]) -> str:
            # Out ops forward references to their outputs; other multi-result ops
            # move their locals into a tuple.
            kind = func.func.kind()
            if len(return_argument) == 0:
                return ''
            if len(return_argument) == 1:
                return ''.join([' ', return_argument[0]])
            if kind == SchemaKind.out:
                return ''.join([' ', f"std::forward_as_tuple({', '.join(return_argument)})"])
            move_args = ', '.join(f'std::move({arg})' for arg in return_argument)
            return ''.join([' ', f'std::make_tuple({move_args})'])

        # First pass: index every entry by function name so structured_inherit
        # can refer to entries that appear later in ``es``. The deep copies keep
        # untouched templates while the second pass pops keys from ``es``.
        funcname_map: Dict[str, NativeFunction] = {}
        struct_map: Dict[str, Dict] = {}
        for e in es:
            e.pop('acl_op', None)
            e.pop('op_api', None)
            schema_str = e.get('func', None)
            op_name = gen_func_name(schema_str)

            funcname_map[op_name] = functions_by_schema.get(schema_str)
            struct_map[op_name] = copy.deepcopy(e)

        struct_infos: List[StructInfo] = []
        for e in es:
            schema_str = e.pop('func', None)
            defn_name = gen_func_name(schema_str)
            schema_function = functions_by_schema.get(schema_str)

            if not schema_function:
                avail = "\n".join(k for k in functions_by_schema if gen_func_name(k) == defn_name)
                raise RuntimeError(
                    f"could not find ATen function for schema: {schema_str} .  Available signatures:\n{avail}"
                )

            func_kind = schema_function.func.kind()

            if 'op_api' not in schema_function.impl_ns:
                raise RuntimeError(
                    f"The Aten function {schema_str} has no op_api implement in op_plugin_functions yaml"
                )
            acl_op = 'acl_op' in schema_function.impl_ns

            gen_opapi_info = e.get('gen_opapi')
            if not isinstance(gen_opapi_info, dict):
                raise RuntimeError(f"The Aten function {schema_str} has invalid gen_opapi configuration")

            structured_inherit = gen_opapi_info.pop('structured_inherit', None)

            if structured_inherit is not None:
                if func_kind == SchemaKind.inplace:
                    inherit_function = funcname_map.get(structured_inherit)
                    # Fail early, like the non-inplace path below, instead of
                    # emitting a StructInfo without its inheritance target.
                    if inherit_function is None:
                        raise RuntimeError(f'The structured_inherit func {structured_inherit} is None')
                    struct_info = StructInfo(
                        name=defn_name,
                        func=schema_function,
                        structured_inherit=inherit_function,
                        acl_op=acl_op,
                    )
                    struct_infos.append(struct_info)
                    continue
                inherit_entry = struct_map.get(structured_inherit)
                if inherit_entry is None:
                    raise RuntimeError(f'The structured_inherit func {structured_inherit} is None')
                inherit_opapi = inherit_entry.get('gen_opapi')
                if not isinstance(inherit_opapi, dict):
                    raise RuntimeError(
                        f"The structured_inherit func {structured_inherit} has invalid gen_opapi configuration"
                    )
                # Work on a private copy: the pops below (and any consumer of the
                # result config) must not drain the shared template. Otherwise a
                # second function inheriting from the same base finds 'exec' gone.
                gen_opapi_info = copy.deepcopy(inherit_opapi)

            integral_identity_tensor = gen_opapi_info.pop('integral_identity_tensor', None)
            aclnn_arguments = gen_opapi_info.pop('exec', None)
            if aclnn_arguments is None:
                raise RuntimeError(f"The Aten function {schema_str}'s gen_opapi has no exec configuration")

            aclnn_arguments_list = [argument.strip() for argument in aclnn_arguments.split(',')]
            aclnn_name = aclnn_arguments_list[0]
            new_params_dict = gen_opapi_info.pop('new_params', dict())

            cpu_scalar_h2d = gen_opapi_info.pop('cpu_scalar_h2d', None)
            cpu_scalar_op_raw = gen_opapi_info.pop('cpu_scalar_op', None)
            cpu_scalar_op = CpuScalarOp.parse(cpu_scalar_op_raw) if cpu_scalar_op_raw else None

            # ``exec`` naming only the aclnn function means "pass the schema's
            # arguments through unchanged".
            cmd_args_expand = False
            if len(aclnn_arguments_list) == 1:
                cmd_args_expand = True
                with native_function_manager(schema_function):
                    sig = NativeSignature(schema_function.func, prefix='', symint=False)
                    aclnn_arguments_list.extend(a.name for a in sig.arguments())

            aclnn_arguments = ', '.join(aclnn_arguments_list)

            if func_kind == SchemaKind.out:
                # After the known keys were popped, only output names may remain.
                output_names = cpp.return_names(schema_function)
                for key in gen_opapi_info:
                    if key not in output_names:
                        raise ValueError(f"Result information contains invalid key: {key} in {schema_str}")

            use_structured_meta = bool(e.get('structured', False))
            results = [] if use_structured_meta else ResInfo.parse(gen_opapi_info, schema_function)
            return_argument = get_return_arguments(schema_function, results)
            return_args = format_return_args(schema_function, return_argument)

            # Guard on return_argument so no dangling ", " is emitted when there is nothing to append.
            if cmd_args_expand and func_kind == SchemaKind.functional and return_argument:
                aclnn_arguments = f"{aclnn_arguments}, {', '.join(return_argument)}"

            struct_info = StructInfo(
                name=defn_name,
                aclnn_name=aclnn_name,
                cmd_args=aclnn_arguments,
                func=schema_function,
                results=results,
                return_args=return_args,
                acl_op=acl_op,
                new_params=new_params_dict,
                integral_identity_tensor=integral_identity_tensor,
                cpu_scalar_h2d=cpu_scalar_h2d,
                cpu_scalar_op=cpu_scalar_op,
                use_structured_meta=use_structured_meta
            )
            struct_infos.append(struct_info)

        return struct_infos