/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file conv3d_tiling_util.cpp
 * \brief
 */

#include "conv3d_tiling_util.h"

namespace Conv3dTilingApi {

int64_t LCM(int64_t numL, int64_t numR)
{
    if (numR == 0 || numL == 0) {
        return 1;
    }
    int64_t product = numL * numR;
    while (numL % numR != 0) {
        int64_t tmp = numL % numR;
        numL = numR;
        numR = tmp;
    }

    return product / numR;
}

uint64_t CeilDiv(uint64_t a, uint64_t b)
{
    if (b == 0) {
        return b;
    }
    return (a + b - 1) / b;
}

uint64_t AlignB(uint64_t a, uint64_t b)
{
    if (b == 0) {
        return 0;
    }
    return ((a + b - 1) / b) * b;
}

uint64_t Gcd(uint64_t a, uint64_t b)
{
    if (b == 0) {
        return a;
    }

    uint64_t c;
    if (a < b) {
        c = a;
        a = b;
        b = c;
    }

    while (a % b != 0) {
        c = a % b;
        a = b;
        b = c;
    }

    return b;
}

void CalcCommFactorWithPowerOfTwo(const uint64_t num, const uint64_t numMax, std::vector<uint64_t>& resList)
{
    uint64_t sqrtMax = static_cast<uint64_t>(sqrt(num));
    for (uint64_t i = 1; i <= sqrtMax; ++i) {
        if (num % i == 0) {
            if (i <= numMax) {
                resList.emplace_back(i);
            }
            uint64_t right = num / i;
            if (right != i && right <= numMax) {
                resList.emplace_back(right);
            }
        }
    }

    for (uint64_t i = CONST_VALUE_2; i <= std::min(num, numMax); i *= CONST_VALUE_2) {
        if (std::find(resList.begin(), resList.end(), i) == resList.end()) {
            resList.emplace_back(i);
        }
    }
    sort(resList.begin(), resList.end());
}

void CalcCommFactor(const uint64_t num, const uint64_t numMax, std::vector<uint64_t>& resList)
{
    uint64_t sqrtMax = static_cast<uint64_t>(sqrt(num));
    for (uint64_t i = 1; i <= sqrtMax; ++i) {
        if (num % i == 0) {
            if (i <= numMax) {
                resList.emplace_back(i);
            }
            uint64_t right = num / i;
            if (right != i && right <= numMax) {
                resList.emplace_back(right);
            }
        }
    }
    sort(resList.begin(), resList.end());
}

void CalcFactorPointWise(uint64_t numMax, std::vector<uint64_t>& resList)
{
    numMax = numMax < CONST_VALUE_2 ? CONST_VALUE_2 : numMax;
    for (uint64_t i = CONST_VALUE_2; i <= numMax; i = i + CONST_VALUE_2) {
        resList.emplace_back(i);
    }
    sort(resList.begin(), resList.end());
}

void VectorElementMultip(std::vector<uint64_t>& range, const uint64_t value)
{
    for (auto& factor : range) {
        factor *= value;
    }
}

bool IsArrayEqual(
    const std::vector<ConvCommonApi::ConvDtype>& arr1, const std::vector<ConvCommonApi::ConvDtype>& arr2, uint32_t size)
{
    if (arr1.size() < size || arr2.size() < size) {
        return false;
    }
    for (size_t i = 0; i < size; i++) {
        if (arr1[i] != arr2[i]) {
            return false;
        }
    }
    return true;
}

uint64_t InferHiL1(uint64_t inputHoL1, uint64_t hi, uint64_t singlekH, uint32_t dilationH, uint32_t strideH)
{
    uint64_t khDilated = (singlekH - 1) * dilationH + 1;
    uint64_t tmpHiL1 = (inputHoL1 - 1) * strideH + khDilated;
    if (tmpHiL1 > hi) {
        tmpHiL1 = hi;
    }

    return tmpHiL1;
}
} // namespace Conv3dTilingApi