/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "tensor_shape_impl.h"

#include "log.h"

namespace aicpu {
/*
 * Get the size of every dim of the tensor shape.
 * @return one entry per dim, in dim order, copied from each proto dim's size field.
 */
std::vector<int64_t> TensorShapeImpl::GetDimSizes() const
{
    const int32_t dim_num = tensor_shape_->dim_size();
    std::vector<int64_t> ret;
    // The final length is known up front: allocate once instead of letting
    // emplace_back grow the buffer repeatedly.
    ret.reserve(static_cast<size_t>(dim_num));
    for (int32_t i = 0; i < dim_num; i++) {
        ret.emplace_back(tensor_shape_->dim(i).size());
    }
    return ret;
}

/*
 * set dims value to tensor shape.
 */
void TensorShapeImpl::SetDimSizes(const std::vector<int64_t>& dims)
{
    tensor_shape_->clear_dim();
    for (size_t i = 0; i < dims.size(); ++i) {
        aicpuops::TensorShape_Dim* aicpu_dims = tensor_shape_->add_dim();
        KERNEL_CHECK_NULLPTR_VOID(aicpu_dims, "Protobuf add dim is null")
        aicpu_dims->set_size(dims[i]);
    }
}

/*
 * Get the data format of the tensor shape.
 * @return the proto's data_format field, converted to Format.
 */
Format TensorShapeImpl::GetFormat() const
{
    const auto raw_format = tensor_shape_->data_format();
    return static_cast<Format>(raw_format);
}

/*
 * Store the given data format into the tensor shape proto.
 * @param format new value for the proto's data_format field
 */
void TensorShapeImpl::SetFormat(Format format)
{
    tensor_shape_->set_data_format(format);
}

/*
 * Report whether the tensor shape is flagged as having an unknown rank.
 * @return the proto's unknown_rank field.
 */
bool TensorShapeImpl::GetUnknownRank() const
{
    const bool is_unknown = tensor_shape_->unknown_rank();
    return is_unknown;
}

/*
 * Set or clear the unknown-rank flag of the tensor shape.
 * @param unknown_rank new value for the proto's unknown_rank field
 */
void TensorShapeImpl::SetUnknownRank(bool unknown_rank)
{
    tensor_shape_->set_unknown_rank(unknown_rank);
}

/*
 * Get the number of dims (the rank) stored in the tensor shape.
 * @return the number of entries in the proto's repeated dim field.
 */
int32_t TensorShapeImpl::GetDims() const
{
    const int32_t dim_count = tensor_shape_->dim_size();
    return dim_count;
}

/*
 * Get the size of the dim at the given index.
 * @param index zero-based dim index; valid range is [0, GetDims())
 * @return that dim's size, or 0 after logging an error when index is out of range.
 */
int64_t TensorShapeImpl::GetDimSize(int32_t index) const
{
    // Query the rank once; it is needed for both the bounds check and the log.
    const int32_t dim_num = GetDims();
    if ((index < 0) || (index >= dim_num)) {
        // The upper bound is exclusive (index == dim_num is rejected), so the
        // message must say "less than", not "not greater than".
        KERNEL_LOG_ERROR(
            "Dim index[%d] must be not less than 0 and less than dims "
            "size[%d]",
            index, dim_num);
        return 0;
    }

    return tensor_shape_->dim(index).size();
}

/*
 * Get the total number of data elements described by the shape, i.e. the
 * product of all dim sizes. A shape with no dims yields 1.
 * @return the element count; -1 as soon as any dim size is negative
 *         (presumably an unknown/dynamic dim — confirm). Each multiplication
 *         goes through KERNEL_CHECK_ASSIGN_64S_MULTI with -1 as its error
 *         value; presumably that macro returns -1 on int64 overflow (it is
 *         defined elsewhere — confirm).
 */
int64_t TensorShapeImpl::NumElements() const
{
    int64_t total = 1;
    for (const auto& shape_dim : tensor_shape_->dim()) {
        int64_t extent = shape_dim.size();
        if (extent < 0) {
            return -1;
        }

        KERNEL_CHECK_ASSIGN_64S_MULTI(total, extent, total, -1);
    }
    return total;
}

/*
 * Get the underlying tensor shape proto.
 * @return aicpuops::TensorShape*: raw pointer obtained from tensor_shape_.get();
 *         ownership stays with this object, so callers must not delete it.
 */
aicpuops::TensorShape* TensorShapeImpl::GetProto() const
{
    aicpuops::TensorShape* proto = tensor_shape_.get();
    return proto;
}
} // namespace aicpu