#include "test_common.h"



#include <gtest/gtest.h>

#include <pto/pto-inst.hpp>



using namespace std;

using namespace PtoTestCommon;



// Host-side launcher for TPOWS test case `caseId`.
// Only declared here: no definition appears in this file, so the definition and the
// instantiations for every caseId used below must be provided by another translation
// unit (presumably the kernel source — confirm).
//   out    - destination buffer; TPowSTestFramework passes its aclrtMalloc'd dstDevice
//   src    - source buffer; TPowSTestFramework passes its aclrtMalloc'd srcDevice
//   scalar - pointer to the scalar operand; TPowSTestFramework passes the address of a
//            host-side T read from scalar.bin
//   stream - stream to launch on; the caller synchronizes it after this returns
template <uint32_t caseId>

void launchTPOWSTestCase(void *out, void *src, void *scalar, aclrtStream stream);



// gtest fixture shared by all TPOWS cases. The cases keep no per-test state;
// TPowSTestFramework does its own ACL setup and teardown, so both hooks are no-ops.
class TPOWSTest : public testing::Test {
protected:
    void SetUp() override {}
    void TearDown() override {}
};



std::string GetGoldenDir()

{

    const testing::TestInfo *testInfo = testing::UnitTest::GetInstance()->current_test_info();

    const std::string caseName = testInfo->name();

    std::string suiteName = testInfo->test_suite_name();

    std::string fullPath = "../" + suiteName + "." + caseName;

    return fullPath;

}



/**
 * Sets the first oRow * oCol elements of `dstDevice` to T{0}.
 *
 * The writes are ordinary host stores, so the buffer must be host-accessible.
 * NOTE(review): despite the parameter name, memory from aclrtMalloc is presumably
 * device memory and should be zeroed through a host staging buffer instead.
 *
 * @tparam T     element type
 * @tparam oRow  number of rows to clear
 * @tparam oCol  number of columns to clear
 * @param dstDevice  buffer of at least oRow * oCol elements
 */
template <typename T, int oRow, int oCol>
inline void InitDstDevice(T *dstDevice)
{
    constexpr int count = oRow * oCol;
    int idx = 0;
    while (idx < count) {
        dstDevice[idx++] = T{0};
    }
}



/**
 * Runs one TPOWS case end to end and compares the device result with golden data.
 *
 * Files under GetGoldenDir():
 *   scalar.bin - one T, passed (by address) to the launcher
 *   input.bin  - iRow * iCol elements of T, uploaded as the source buffer
 *   golden.bin - oRow * oCol elements of T, the expected result
 *   output.bin - written here with the oRow * oCol device result, then read back
 *
 * @tparam caseId    selects the launchTPOWSTestCase<caseId> instantiation
 * @tparam T         element type of all buffers and of the scalar
 * @tparam validRow  default for iRow/oRow (not otherwise used in this function)
 * @tparam validCol  default for iCol/oCol (not otherwise used in this function)
 * @tparam iRow, iCol  source buffer shape
 * @tparam oRow, oCol  destination buffer shape
 * @return false if scalar.bin cannot be read in full; otherwise the result of
 *         ResultCmp<T>(golden, output, 0.001f).
 */
template <uint32_t caseId, typename T, int validRow, int validCol, int iRow = validRow, int iCol = validCol,
          int oRow = validRow, int oCol = validCol>
bool TPowSTestFramework()
{
    const std::string goldenDir = GetGoldenDir();

    // Read the scalar before acquiring any ACL resources. A missing or short file then fails
    // the case cleanly instead of launching the kernel with an uninitialized scalar.
    T scalar{};
    {
        std::ifstream file(goldenDir + "/scalar.bin", std::ios::binary);
        if (!file.read(reinterpret_cast<char *>(&scalar), sizeof(T))) {
            return false;
        }
    }

    aclInit(nullptr);
    aclrtSetDevice(0);

    aclrtStream stream = nullptr;
    aclrtCreateStream(&stream);

    const size_t dstByteSize = static_cast<size_t>(oRow) * oCol * sizeof(T);
    const size_t srcByteSize = static_cast<size_t>(iRow) * iCol * sizeof(T);
    size_t readSize = 0;
    T *dstHost = nullptr;
    T *srcHost = nullptr;
    T *dstDevice = nullptr;
    T *srcDevice = nullptr;

    aclrtMallocHost(reinterpret_cast<void **>(&dstHost), dstByteSize);
    aclrtMallocHost(reinterpret_cast<void **>(&srcHost), srcByteSize);

    aclrtMalloc(reinterpret_cast<void **>(&dstDevice), dstByteSize, ACL_MEM_MALLOC_HUGE_FIRST);
    aclrtMalloc(reinterpret_cast<void **>(&srcDevice), srcByteSize, ACL_MEM_MALLOC_HUGE_FIRST);

    // dstDevice comes from aclrtMalloc and must not be written with host stores. Zero the
    // host staging buffer and upload it, the same way the source buffer is filled. This
    // matters because all oRow * oCol elements are compared against golden.bin, including
    // any the kernel presumably leaves untouched (e.g. outside the valid region).
    InitDstDevice<T, oRow, oCol>(dstHost);
    aclrtMemcpy(dstDevice, dstByteSize, dstHost, dstByteSize, ACL_MEMCPY_HOST_TO_DEVICE);

    ReadFile(goldenDir + "/input.bin", readSize, srcHost, srcByteSize);
    aclrtMemcpy(srcDevice, srcByteSize, srcHost, srcByteSize, ACL_MEMCPY_HOST_TO_DEVICE);

    launchTPOWSTestCase<caseId>(dstDevice, srcDevice, &scalar, stream);
    aclrtSynchronizeStream(stream);

    aclrtMemcpy(dstHost, dstByteSize, dstDevice, dstByteSize, ACL_MEMCPY_DEVICE_TO_HOST);
    WriteFile(goldenDir + "/output.bin", dstHost, dstByteSize);

    aclrtFree(dstDevice);
    aclrtFree(srcDevice);
    aclrtFreeHost(dstHost);
    aclrtFreeHost(srcHost);
    aclrtDestroyStream(stream);
    aclrtResetDevice(0);
    aclFinalize();

    std::vector<T> golden(oRow * oCol);
    std::vector<T> devFinal(oRow * oCol);
    size_t goldenReadSize = 0;
    size_t outputReadSize = 0;
    ReadFile(goldenDir + "/golden.bin", goldenReadSize, golden.data(), dstByteSize);
    ReadFile(goldenDir + "/output.bin", outputReadSize, devFinal.data(), dstByteSize);

    return ResultCmp<T>(golden, devFinal, 0.001f);
}



// float; 32x64 valid region, input/output buffers default to the same shape.
TEST_F(TPOWSTest, case1) {
    const bool ret{TPowSTestFramework<1, float, 32, 64>()};
    EXPECT_TRUE(ret);
}



// aclFloat16; 63x64 valid region, input/output buffers default to the same shape.
TEST_F(TPOWSTest, case2) {
    const bool ret{TPowSTestFramework<2, aclFloat16, 63, 64>()};
    EXPECT_TRUE(ret);
}



// int32_t; 31x128 valid region, input/output buffers default to the same shape.
TEST_F(TPOWSTest, case3) {
    const bool ret{TPowSTestFramework<3, int32_t, 31, 128>()};
    EXPECT_TRUE(ret);
}



// int16_t; 15x192 valid region, input/output buffers default to the same shape.
TEST_F(TPOWSTest, case4) {
    const bool ret{TPowSTestFramework<4, int16_t, 15, 192>()};
    EXPECT_TRUE(ret);
}



// float; 7x448 valid region, input/output buffers default to the same shape.
TEST_F(TPOWSTest, case5) {
    const bool ret{TPowSTestFramework<5, float, 7, 448>()};
    EXPECT_TRUE(ret);
}



// float; 256x16 valid region, input/output buffers default to the same shape.
TEST_F(TPOWSTest, case6) {
    const bool ret{TPowSTestFramework<6, float, 256, 16>()};
    EXPECT_TRUE(ret);
}



// float; 16x16 valid region with a larger 32x32 input buffer and 64x64 output buffer.
TEST_F(TPOWSTest, case7) {
    const bool ret{TPowSTestFramework<7, float, 16, 16, 32, 32, 64, 64>()};
    EXPECT_TRUE(ret);
}