/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
 */
#include "kernel_operator.h"
#include "kernel_tiling/kernel_tiling.h"
#include "scatter_add_grad_base.h"
#include "scatter_add_grad.h"
#include "scatter_add_grad_line.h"
#include "scatter_add_grad_large.h"

using namespace ScatterAddGradNS;

extern "C" __global__ __aicore__ void scatter_add_grad_v1(GM_ADDR grad_out, GM_ADDR index, GM_ADDR grad_in, GM_ADDR workspace, GM_ADDR tiling)
{
    GET_TILING_DATA(tilingData, tiling);
    if (TILING_KEY_IS(2)) {
        ScatterAddGradNS::ScatterAddGradLine<float> op;
        op.InitLine(grad_out, index, grad_in, &tilingData);
        op.Process();
    }
    if (TILING_KEY_IS(1)) {
        ScatterAddGradNS::ScatterAddGradV1<float> op;
        op.Init(grad_out, index, grad_in, &tilingData);
        op.Process();
    }
    else if (TILING_KEY_IS(3)) {
        ScatterAddGradNS::ScatterAddGradLarge<float> op;
        op.Init(grad_out, index, grad_in, &tilingData);
        op.Process();
    }
}