/**

 * Copyright (c) 2026 Huawei Technologies Co., Ltd.

 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of

 * CANN Open Software License Agreement Version 2.0 (the "License").

 * Please refer to the License for details. You may not use this file except in compliance with the License.

 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

 * See LICENSE in the root of the software repository for the full text of the License.

 */



#include "acl/acl.h"
#include "aclnnop/aclnn_interleave_rope.h"

#include <cstdint>
#include <cstdio>
#include <iostream>
#include <vector>



// Evaluates `cond`; when it is false, executes `return_expr` (typically a log
// call followed by a `return`). The do/while(0) wrapper makes the expansion a
// single statement, so the macro is safe inside unbraced if/else branches.
// NOTE: each continued line must be directly followed by the next macro line;
// a blank line after a trailing backslash ends the definition early.
#define CHECK_RET(cond, return_expr) \
  do {                               \
    if (!(cond)) {                   \
      return_expr;                   \
    }                                \
  } while (0)



// printf-style logging helper: forwards `message` and any extra arguments to
// printf. `##__VA_ARGS__` drops the trailing comma when no extra arguments
// are supplied (a GNU-style preprocessor extension).
// NOTE: each continued line must be directly followed by the next macro line;
// a blank line after a trailing backslash ends the definition early.
#define LOG_PRINT(message, ...)     \
  do {                              \
    printf(message, ##__VA_ARGS__); \
  } while (0)



/**
 * Returns the product of all dimensions in `shape`, i.e. the element count
 * of a tensor with that shape. An empty shape yields 1.
 */
int64_t GetShapeSize(const std::vector<int64_t>& shape) {
    int64_t elementCount = 1;
    for (std::vector<int64_t>::size_type dim = 0; dim < shape.size(); ++dim) {
        elementCount *= shape[dim];
    }
    return elementCount;
}



std::vector<aclFloat16> ConvertToFloat16(const std::vector<float>& data) {

    std::vector<aclFloat16> converted;

    converted.reserve(data.size());

    for (float value : data) {

        converted.push_back(aclFloatToFloat16(value));

    }

    return converted;

}



int Init(int32_t deviceId, aclrtStream* stream) {

    // 固定写法,资源初始化

    auto ret = aclInit(nullptr);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);

    ret = aclrtSetDevice(deviceId);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);

    ret = aclrtCreateStream(stream);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);

    return 0;

}



template <typename T>

int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr,

                    aclDataType dataType, aclTensor** tensor) {

    auto size = GetShapeSize(shape) * sizeof(T);

    // 调用aclrtMalloc申请device侧内存

    auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret);



    // 调用aclrtMemcpy将host侧数据拷贝到device侧内存上

    ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret);



    // 计算连续tensor的strides

    std::vector<int64_t> strides(shape.size(), 1);

    for (int64_t i = shape.size() - 2; i >= 0; i--) {

        strides[i] = shape[i + 1] * strides[i + 1];

    }



    // 调用aclCreateTensor接口创建aclTensor

    *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND,

                              shape.data(), shape.size(), *deviceAddr);

    return 0;

}



int main() {

    // 1. 固定写法,device/stream初始化, 参考AscendCL对外接口列表

    // 根据自己的实际device填写deviceId

    int32_t deviceId = 0;

    aclrtStream stream;

    auto ret = Init(deviceId, &stream);

    // check根据自己的需要处理

    CHECK_RET(ret == 0, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret);

    // 2. 构造输入与输出,需要根据API的接口定义构造

    // interleave_rope约束:D=64, N=1

    std::vector<int64_t> xShape = {1, 1, 1, 64};

    std::vector<int64_t> cosShape = {1, 1, 1, 64};

    std::vector<int64_t> sinShape = {1, 1, 1, 64};

    std::vector<int64_t> outShape = {1, 1, 1, 64};



    void* xDeviceAddr = nullptr;

    void* cosDeviceAddr = nullptr;

    void* sinDeviceAddr = nullptr;

    void* outDeviceAddr = nullptr;

    aclTensor* x = nullptr;

    aclTensor* cos = nullptr;

    aclTensor* sin = nullptr;

    aclTensor* out = nullptr;



    std::vector<aclFloat16> xHostData = ConvertToFloat16({

        74, 54, 84, 125, 23, 78, 37, 72, 27, 98, 34, 107, 29, 23, 54, 60,

        70, 49, 119, 54, 29, 54, 41, 99, 27, 62, 5, 46, 108, 39, 24, 123,

        33, 82, 6, 40, 88, 24, 6, 116, 38, 119, 110, 5, 30, 79, 87, 18,

        29, 100, 90, 24, 21, 93, 63, 68, 34, 112, 119, 48, 74, 43, 85, 64

    });

    std::vector<aclFloat16> cosHostData = ConvertToFloat16({

        41, 37, 17, 25, 49, 25, 22, 24, 110, 120, 107, 3, 82, 66, 75, 86,

        85, 115, 110, 56, 52, 39, 86, 23, 36, 71, 20, 73, 113, 25, 114, 56,

        125, 80, 95, 82, 31, 63, 99, 62, 23, 55, 30, 99, 42, 121, 15, 24,

        97, 87, 81, 67, 43, 21, 13, 9, 33, 29, 117, 10, 114, 61, 98, 15

    });

    std::vector<aclFloat16> sinHostData = ConvertToFloat16({

        46, 56, 56, 101, 66, 10, 96, 16, 86, 57, 102, 66, 12, 105, 76, 58,

        90, 6, 79, 128, 126, 82, 41, 3, 45, 7, 66, 4, 46, 22, 31, 26,

        37, 63, 97, 84, 91, 90, 47, 77, 90, 34, 41, 83, 91, 108, 120, 13,

        90, 32, 85, 37, 119, 31, 51, 82, 122, 125, 7, 116, 121, 108, 38, 56

    });

    std::vector<aclFloat16> outHostData(64, aclFloatToFloat16(0.0f));



    // 创建x aclTensor

    ret = CreateAclTensor(xHostData, xShape, &xDeviceAddr, aclDataType::ACL_FLOAT16, &x);

    CHECK_RET(ret == ACL_SUCCESS, return ret);

    // 创建cos aclTensor

    ret = CreateAclTensor(cosHostData, cosShape, &cosDeviceAddr, aclDataType::ACL_FLOAT16, &cos);

    CHECK_RET(ret == ACL_SUCCESS, return ret);

    // 创建sin aclTensor

    ret = CreateAclTensor(sinHostData, sinShape, &sinDeviceAddr, aclDataType::ACL_FLOAT16, &sin);

    CHECK_RET(ret == ACL_SUCCESS, return ret);

    // 创建out aclTensor

    ret = CreateAclTensor(outHostData, outShape, &outDeviceAddr, aclDataType::ACL_FLOAT16, &out);

    CHECK_RET(ret == ACL_SUCCESS, return ret);



    // 3. 调用CANN算子库API,需要修改为具体的API

    uint64_t workspaceSize = 0;

    aclOpExecutor* executor;

    // 调用aclnnInterleaveRope第一段接口

    ret = aclnnInterleaveRopeGetWorkspaceSize(x, cos, sin, out, &workspaceSize, &executor);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnInterleaveRopeGetWorkspaceSize failed. ERROR: %d\n", ret); return ret);

    // 根据第一段接口计算出的workspaceSize申请device内存

    void* workspaceAddr = nullptr;

    if (workspaceSize > 0) {

        ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);

        CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret);

    }

    // 调用aclnnInterleaveRope第二段接口

    ret = aclnnInterleaveRope(workspaceAddr, workspaceSize, executor, stream);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnInterleaveRope failed. ERROR: %d\n", ret); return ret);

    // 4. 固定写法,同步等待任务执行结束

    ret = aclrtSynchronizeStream(stream);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret);

    // 5. 获取输出的值,将device侧内存上的结果拷贝至host侧,需要根据具体API的接口定义修改

    auto size = GetShapeSize(outShape);

    std::vector<aclFloat16> resultData(size, aclFloatToFloat16(0.0f));

    ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), outDeviceAddr,

                      size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST);

    CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return ret);



    for (int64_t i = 0; i < size; i++) {

        LOG_PRINT("result[%ld] is: %f\n", i, aclFloat16ToFloat(resultData[i]));

    }



    // 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改

    aclDestroyTensor(x);

    aclDestroyTensor(cos);

    aclDestroyTensor(sin);

    aclDestroyTensor(out);



    // 7. 释放device 资源

    aclrtFree(xDeviceAddr);

    aclrtFree(cosDeviceAddr);

    aclrtFree(sinDeviceAddr);

    aclrtFree(outDeviceAddr);

    if (workspaceSize > 0) {

      aclrtFree(workspaceAddr);

    }

    aclrtDestroyStream(stream);

    aclrtResetDevice(deviceId);

    aclFinalize();



    return 0;

}