/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of 
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, 
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
* \file test_isfinite.cpp
* \brief
*/

#include <cmath>
#include "gtest/gtest.h"
#include "kernel_operator.h"
#include "tikicpulib.h"

#include "test_api_utils.h"
#include "utils.h"
#include "cast.h"

using namespace AscendC;


constexpr int gen_index_two = 2;
constexpr int gen_index_three = 3;
constexpr int gen_index_five = 5;
constexpr int gen_index_div = 1000;
constexpr float gen_float_suffix = 0.12;

template<typename InT, typename OutT>
void CastExtendCalc(InT *x, OutT *y, int size) {
  TPipe tpipe;
  TBuf<TPosition::VECCALC> xbuf, ybuf, tmp;
  tpipe.InitBuffer(xbuf, sizeof(InT) * size);
  tpipe.InitBuffer(ybuf, sizeof(OutT) * size);
  tpipe.InitBuffer(tmp, 8192);

  auto l_x = xbuf.Get<InT>();
  auto l_y = ybuf.Get<OutT>();
  auto l_tmp = tmp.Get<uint8_t>();

  GmToUb(l_x, x, size);
  GmToUb(l_y, y, size);

  CastExtend(l_y, l_x, l_tmp, size);

  UbToGm(y, l_y, size);
}

template<typename InT, typename OutT>
void CastExtendTest(int size, std::function<OutT(int index, InT src)> expectGen,
  std::function<InT(int index)> srcGen, std::function<bool(OutT a, OutT b)> compare
) {
  // 构造测试输入和预期结果
  auto *x = (InT*)AscendC::GmAlloc(sizeof(InT) * size);
  auto *y = (OutT*)AscendC::GmAlloc(sizeof(OutT) * size);

  OutT expect[size];

  for (int i = 0; i < size; i++) {
    auto srcGenValue = srcGen(i);
    auto expectGenValue = expectGen(i, x[i]);
    x[i] = srcGenValue;
    expect[i] = expectGenValue;
  }

  // 构造Api调用函数
  auto kernel = [](int size, InT *x, OutT *y) {
    CastExtendCalc<InT, OutT>(x, y, size);
  };

  // 调用kernel
  AscendC::SetKernelMode(KernelMode::AIV_MODE);
  ICPU_RUN_KF(kernel, 1, size, x, y);

  // 验证结果
  int diff_count = 0;
  for (int i = 0; i < size; i++) {
    auto xValue = x[i];
    auto yValue = y[i];
    auto expectValue = expect[i];
    if (!compare(yValue, expectValue)) {
      diff_count++;
    }
  }

  EXPECT_EQ(diff_count, 0);
}

constexpr uint8_t CastGen(const int index) {
  return index % gen_index_two;
}

constexpr bool CastCompare(const float a, const float b) {
  auto diff = a - b;
  if (diff < -1e-5 || diff > 1e-5) {
    return false;
  }
  return true;
}

constexpr float ExpectGenU82F32(const int index, const float x) {
  return (float)(index % gen_index_two);
}

class TestApiCastUint8 : public ::testing::Test, public testing::WithParamInterface<size_t> {
};

TEST_P(TestApiCastUint8, Calc) {
  const int size = this->GetParam();
  CastExtendTest<uint8_t, float>(size, ExpectGenU82F32, CastGen, CastCompare);
}

INSTANTIATE_TEST_SUITE_P(DiffLength, TestApiCastUint8,
                         ::testing::Values(
                           /* 1 block */ ONE_BLK_SIZE / sizeof(float),
                           /* 1 repeat */ ONE_REPEAT_BYTE_SIZE / sizeof(float),
                           /* max repeat */ MAX_REPEAT_NUM *ONE_REPEAT_BYTE_SIZE / sizeof(float),
                           /* less than 1 block */ (ONE_BLK_SIZE - sizeof(float)) / sizeof(float),
                           /* less than 1 repeat */ (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) / sizeof(float),
                           /* less than max repeat */ (MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE / sizeof(float),
                           /* mix block, repeat, max repeat*/
                           ((MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE + (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) +
                             (ONE_BLK_SIZE - sizeof(float))) / sizeof(float)));

constexpr int64_t CastGenS642U8(const int index) {
  return index % gen_index_div;
}

constexpr bool CastCompareS642U8(const uint8_t a, const uint8_t b) {
  return (a - b == 0);
}

constexpr uint8_t ExpectGenS642U8(const int index, const uint8_t x) {
  auto tmp = index % gen_index_div;
  if (tmp > 255) {
    return 255;
  } else {
    return tmp;
  }
}

class TestApiCastS642U8 : public ::testing::Test, public testing::WithParamInterface<size_t> {
};

TEST_P(TestApiCastS642U8, Calc) {
  const int size = this->GetParam();
  CastExtendTest<int64_t, uint8_t>(size, ExpectGenS642U8, CastGenS642U8, CastCompareS642U8);
}

INSTANTIATE_TEST_SUITE_P(DiffLength, TestApiCastS642U8,
                         ::testing::Values(
                           /* 1 block */ ONE_BLK_SIZE / sizeof(int64_t),
                           /* 1 repeat */ ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* max repeat */ MAX_REPEAT_NUM *ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* less than 1 block */ (ONE_BLK_SIZE - sizeof(int64_t)) / sizeof(int64_t),
                           /* less than 1 repeat */ (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) / sizeof(int64_t),
                           /* less than max repeat */ (MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* mix block, repeat, max repeat*/
                           ((MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE + (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) +
                             (ONE_BLK_SIZE - sizeof(int64_t))) / sizeof(int64_t)));

constexpr int64_t CastGenS642F16(const int index) {
  return index % gen_index_div;
}

constexpr bool CastCompareS642F16(const double a, const double b) {
  auto diff = a - b;
  if (diff < -1e-5 || diff > 1e-5) {
    return false;
  }
  return true;
}

constexpr double ExpectGenS642F16(const int index, const uint8_t x) {
  return double(index % gen_index_div);
}

class TestApiCastS642F16 : public ::testing::Test, public testing::WithParamInterface<size_t> {
};

TEST_P(TestApiCastS642F16, Calc) {
  const int size = this->GetParam();
  CastExtendTest<int64_t, half>(size, ExpectGenS642F16, CastGenS642F16, CastCompareS642F16);
}

INSTANTIATE_TEST_SUITE_P(DiffLength, TestApiCastS642F16,
                         ::testing::Values(
                           /* 1 block */ ONE_BLK_SIZE / sizeof(int64_t),
                           /* 1 repeat */ ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* max repeat */ MAX_REPEAT_NUM *ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* less than 1 block */ (ONE_BLK_SIZE - sizeof(int64_t)) / sizeof(int64_t),
                           /* less than 1 repeat */ (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) / sizeof(int64_t),
                           /* less than max repeat */ (MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* mix block, repeat, max repeat*/
                           ((MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE + (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) +
                             (ONE_BLK_SIZE - sizeof(int64_t))) / sizeof(int64_t)));

constexpr float CastGenF162S64(const int index) {
  if (index % gen_index_two == 0) {
    return index / gen_index_div + gen_float_suffix;
  }
  if (index % gen_index_three == 0) {
    return (index / gen_index_div + gen_float_suffix) * -1;
  }
  return index / gen_index_div;
}

constexpr bool CastCompareF162S64(const int64_t a, const int64_t b) {
  return (a == b);
}

constexpr int64_t ExpectGenF162S64(const int index, const float x) {
  if (index % gen_index_two == 0) {
    return int64_t(index / gen_index_div + gen_float_suffix);
  }
  if (index % gen_index_three == 0) {
    return int64_t((index / gen_index_div + gen_float_suffix) * - 1);
  }
  return int64_t(index / gen_index_div);
}

class TestApiCastF162S64 : public ::testing::Test, public testing::WithParamInterface<size_t> {
};

TEST_P(TestApiCastF162S64, Calc) {
  const int size = this->GetParam();
  CastExtendTest<half, int64_t>(size, ExpectGenF162S64, CastGenF162S64, CastCompareF162S64);
}

INSTANTIATE_TEST_SUITE_P(DiffLength, TestApiCastF162S64,
                         ::testing::Values(
                           /* 1 block */ ONE_BLK_SIZE / sizeof(int64_t),
                           /* 1 repeat */ ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* max repeat */ MAX_REPEAT_NUM *ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* less than 1 block */ (ONE_BLK_SIZE - sizeof(int64_t)) / sizeof(int64_t),
                           /* less than 1 repeat */ (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) / sizeof(int64_t),
                           /* less than max repeat */ (MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* mix block, repeat, max repeat*/
                           ((MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE + (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) +
                             (ONE_BLK_SIZE - sizeof(int64_t))) / sizeof(int64_t)));

constexpr uint32_t CastGenUint2Int(const int index) {
  if (index % gen_index_two == 0) {
    return (uint32_t)0xFFFFFF85;
  }
  if (index % gen_index_three == 0) {
    return uint32_t((index / gen_index_div) * 100 -1);
  }
  return uint32_t(index / gen_index_div);
}

constexpr bool CastCompareUint2Int(const int32_t a, const int32_t b) {
  return (a == b);
}

constexpr int32_t ExpectGenUint2Int(const int index, const uint32_t x) {
  if (index % gen_index_two == 0) {
    return int32_t(0xFFFFFF85);
  }
  if (index % gen_index_three == 0) {
    return int32_t((index / gen_index_div) * 100 -1);
  }
  return int32_t(index / gen_index_div);
}

class TestApiCastUint2Int : public ::testing::Test, public testing::WithParamInterface<size_t> {
};

TEST_P(TestApiCastUint2Int, Calc) {
  const int size = this->GetParam();
  CastExtendTest<uint32_t, int32_t>(size, ExpectGenUint2Int, CastGenUint2Int, CastCompareUint2Int);
}

INSTANTIATE_TEST_SUITE_P(DiffLength, TestApiCastUint2Int,
                         ::testing::Values(
                           /* 1 block */ ONE_BLK_SIZE / sizeof(int32_t),
                           /* 1 repeat */ ONE_REPEAT_BYTE_SIZE / sizeof(int32_t),
                           /* max repeat */ MAX_REPEAT_NUM *ONE_REPEAT_BYTE_SIZE / sizeof(int32_t),
                           /* less than 1 block */ (ONE_BLK_SIZE - sizeof(int32_t)) / sizeof(int32_t),
                           /* less than 1 repeat */ (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) / sizeof(int32_t),
                           /* less than max repeat */ (MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE / sizeof(int32_t),
                           /* mix block, repeat, max repeat*/
                           ((MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE + (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) +
                             (ONE_BLK_SIZE - sizeof(int32_t))) / sizeof(int32_t)));

constexpr int64_t CastGenS642U64(const int index) {
  if (index % gen_index_two == 0) {
    return (int64_t)0xFFFFFF85;
  }
  if (index % gen_index_three == 0) {
    return int64_t((index / gen_index_div) * 100 -1);
  }
  return int64_t(index / gen_index_div);
}

constexpr bool CastCompareS642U64(const int32_t a, const int32_t b) {
  return (a == b);
}

constexpr uint64_t ExpectGenS642U64(const int index, const int64_t x) {
  if (index % gen_index_two == 0) {
    return uint64_t(0xFFFFFF85);
  }
  if (index % gen_index_three == 0) {
    return uint64_t((index / gen_index_div) * 100 -1);
  }
  return uint64_t(index / gen_index_div);
}

class TestApiCastS642U64 : public ::testing::Test, public testing::WithParamInterface<size_t> {
};

TEST_P(TestApiCastS642U64, Calc) {
  const int size = this->GetParam();
  CastExtendTest<int64_t, uint64_t>(size, ExpectGenS642U64, CastGenS642U64, CastCompareS642U64);
}

INSTANTIATE_TEST_SUITE_P(DiffLength, TestApiCastS642U64,
                         ::testing::Values(
                           /* 1 block */ ONE_BLK_SIZE / sizeof(int64_t),
                           /* 1 repeat */ ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* max repeat */ MAX_REPEAT_NUM *ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* less than 1 block */ (ONE_BLK_SIZE - sizeof(int64_t)) / sizeof(int64_t),
                           /* less than 1 repeat */ (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) / sizeof(int64_t),
                           /* less than max repeat */ (MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE / sizeof(int64_t),
                           /* mix block, repeat, max repeat*/
                           ((MAX_REPEAT_NUM - 1) * ONE_REPEAT_BYTE_SIZE + (ONE_REPEAT_BYTE_SIZE - ONE_BLK_SIZE) +
                             (ONE_BLK_SIZE - sizeof(int64_t))) / sizeof(int64_t)));