* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file tensor_offset.h
* \brief
*/
#pragma once
#include <vector>
#include <set>
#include <string>
#include <memory>
#include <unordered_set>
#include <functional>
#include "tilefwk/error.h"
#include "symbolic_scalar.h"
namespace npu::tile_fwk {
class TensorOffset {
public:
TensorOffset(const std::vector<int64_t>& offset, const std::vector<SymbolicScalar>& dynOffset)
: offset_(offset), dynOffset_(dynOffset)
{}
const std::vector<int64_t>& GetOffset() const { return offset_; }
const std::vector<SymbolicScalar>& GetDynOffset() const { return dynOffset_; }
template <typename Tret, typename Tlhs, typename Trhs>
static std::vector<Tret> AddRaw(const std::vector<Tlhs>& lhs, const std::vector<Trhs>& rhs)
{
FE_ASSERT(FeError::INVALID_VAL, lhs.size() == rhs.size())
<< "lhs.size():" << lhs.size() << ", rhs.size():" << rhs.size();
std::vector<Tret> ret(lhs.size());
for (size_t k = 0; k < lhs.size(); k++) {
ret[k] = lhs[k] + rhs[k];
}
return ret;
}
static std::vector<int64_t> Add(const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs)
{
return AddRaw<int64_t, int64_t, int64_t>(lhs, rhs);
}
static std::vector<SymbolicScalar> Add(const std::vector<SymbolicScalar>& lhs, const std::vector<int64_t>& rhs)
{
return AddRaw<SymbolicScalar, SymbolicScalar, int64_t>(lhs, rhs);
}
static std::vector<SymbolicScalar> Add(const std::vector<int64_t>& lhs, const std::vector<SymbolicScalar>& rhs)
{
return AddRaw<SymbolicScalar, int64_t, SymbolicScalar>(lhs, rhs);
}
static std::vector<SymbolicScalar> Add(
const std::vector<SymbolicScalar>& lhs, const std::vector<SymbolicScalar>& rhs)
{
return AddRaw<SymbolicScalar, SymbolicScalar, SymbolicScalar>(lhs, rhs);
}
static std::vector<int64_t> Zero(const std::vector<int64_t>& off) { return std::vector<int64_t>(off.size(), 0); }
static bool IsZero(const std::vector<int64_t>& off) { return Zero(off) == off; }
static std::pair<std::vector<int64_t>, std::vector<SymbolicScalar>> Add(
const std::vector<int64_t>& lhs, const std::vector<SymbolicScalar>& lhsDyn, const std::vector<int64_t>& rhs,
const std::vector<SymbolicScalar>& rhsDyn)
{
FE_ASSERT(FeError::INVALID_VAL, lhs.size() == rhs.size())
<< "lhs.size():" << lhs.size() << ", rhs.size():" << rhs.size();
std::vector<int64_t> ret = Add(lhs, rhs);
std::vector<SymbolicScalar> retDyn;
if (lhsDyn.size() != 0 && rhsDyn.size() != 0) {
retDyn = Add(lhsDyn, rhsDyn);
} else if (lhsDyn.size() != 0) {
retDyn = Add(lhsDyn, rhs);
} else if (rhsDyn.size() != 0) {
retDyn = Add(lhs, rhsDyn);
}
return std::make_pair(ret, retDyn);
}
static std::vector<int64_t> Sub(const std::vector<int64_t>& lhs, const std::vector<int64_t>& rhs)
{
FE_ASSERT(FeError::INVALID_VAL, lhs.size() == rhs.size())
<< "lhs.size():" << lhs.size() << ", rhs.size():" << rhs.size();
std::vector<int64_t> result(lhs.size());
std::transform(lhs.begin(), lhs.end(), rhs.begin(), result.begin(), [](int a, int b) { return a - b; });
return result;
}
static std::vector<SymbolicScalar> Sub(const std::vector<SymbolicScalar>& lhs, const std::vector<int64_t>& rhs)
{
if (lhs.size() == 0) {
return {};
}
FE_ASSERT(FeError::INVALID_VAL, lhs.size() == rhs.size())
<< "lhs.size():" << lhs.size() << ", rhs.size():" << rhs.size();
std::vector<SymbolicScalar> result(lhs.size());
std::transform(
lhs.begin(), lhs.end(), rhs.begin(), result.begin(), [](const SymbolicScalar& a, int b) { return a - b; });
return result;
}
public:
const std::vector<int64_t>& offset_;
const std::vector<SymbolicScalar>& dynOffset_;
};
}