import torch
from mindspeed.op_builder.builder import MindSpeedOpBuilder


class GroupedMatMulAllReduceOpBuilder(MindSpeedOpBuilder):
    OP_NAME = "npu_grouped_mat_mul_all_reduce"
    TORCH_MAJOR, TORCH_MINOR = map(int, torch.__version__.split('.')[:2])

    def __init__(self):
        super(GroupedMatMulAllReduceOpBuilder, self).__init__(self.OP_NAME)

    def sources(self):
        return ['ops/csrc/cann/npu_grouped_mat_mul_all_reduce.cpp']

    def include_paths(self):
        paths = super().include_paths()
        paths += ['ops/csrc/cann/inc']
        return paths

    def cxx_args(self):
        args = super().cxx_args()
        args += [
            '-Wno-sign-compare',
            '-Wno-deprecated-declarations',
            '-Wno-return-type',
            "-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
        ]
        if self.TORCH_MAJOR >= 2 and self.TORCH_MINOR >= 1:
            cpp_std = " -std=c++17"
            compile_maroc = " -D__TORCH_2__"
        else:
            cpp_std = " -std=c++14"
            compile_maroc = " -D__TORCH_1__"
        args.append(cpp_std)
        args.append(compile_maroc)
        
        return args