* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file logical_tensor.h
* \brief Declares TileRange, LogicalTensor, and the tensor extract/insert and get-tensor-data helper APIs.
*/
#pragma once
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "tilefwk/data_type.h"
#include "interface/inner/pre_def.h"
#include "raw_tensor.h"
#include "interface/operation/attr_holder.h"
#include "symbolic_scalar.h"
#include "tensor_offset.h"
#include <nlohmann/json.hpp>
#include "storage.h"
#include "ir/expr.h"
// Shorthand for the nlohmann JSON value type used by the DumpJson/LoadJson declarations in this header.
using Json = nlohmann::json;
// NOTE(review): this using-directive sits at file scope of a header, so every translation unit that
// includes this file also gets all of namespace pypto injected into its global lookup. Removing it may
// break code that relies on the unqualified names -- confirm before changing.
using namespace pypto;
namespace npu::tile_fwk {
/**
 * Range described by a start and an end position, plus life-time markers and a memory id.
 *
 * Size() is defined as end - start and IsEmpty() as end <= start, so the range behaves like [start, end).
 * lifeStart, lifeEnd and memId default to -1.
 */
class TileRange {
public:
    size_t start;
    size_t end;
    int lifeStart = -1;
    int lifeEnd = -1;
    int memId = -1;

    /**
     * @param s  start position (default 0)
     * @param e  end position (default 0)
     * @param id value stored in memId (default -1)
     */
    explicit TileRange(size_t s = 0, size_t e = 0, int id = -1) : start(s), end(e), memId(id) {}

    /**
     * @return number of positions covered by the range; 0 when the range is empty.
     *
     * start and end are unsigned, so a plain end - start would wrap around to a huge value for an
     * inverted range (end < start) even though IsEmpty() reports it as empty. Guard against that.
     */
    size_t Size() const { return IsEmpty() ? 0 : end - start; }

    /** @return true when end is not greater than start. */
    bool IsEmpty() const { return end <= start; }
};
/**
 * Pairs a shared RawTensor with an offset and a shape, and records the operations registered as its
 * producers, consumers and dependent ops.
 *
 * Instances are neither copyable nor movable (all four copy/move members are deleted); Clone() and
 * View() return a separate LogicalTensor through a shared_ptr.
 */
class LogicalTensor : public std::enable_shared_from_this<LogicalTensor>, public AttrHolder, public ir::Var {
public:
    // Backing raw tensor. The cache-policy accessors tolerate a null pointer; Format() and
    // GetRawMagic() dereference it unconditionally.
    std::shared_ptr<RawTensor> tensor;
    Offset offset;
    Shape shape;
    Shape oriShape;
    // Symbolic counterparts of the offset / valid shape (see GetTensorOffset, UpdateDynValidShape).
    std::vector<SymbolicScalar> dynOffset_;
    std::vector<SymbolicScalar> dynValidShape_;
    Shape storageShape;
    std::shared_ptr<Storage> storage_ = nullptr;
    uint64_t storageOffset_ = 0;
    // NOTE(review): no in-class initializer; presumably assigned by every constructor -- confirm.
    int magic;
    TileRange memoryrange;

    LogicalTensor(
        Function& function, DataType t, Shape tshape, TileOpFormat tformat = TileOpFormat::TILEOP_ND,
        std::string tname = "");
    LogicalTensor(
        Function& function, DataType t, Shape tshape, std::vector<SymbolicScalar> tValidShape,
        TileOpFormat tformat = TileOpFormat::TILEOP_ND, std::string tname = "");
    LogicalTensor(Function& function, std::shared_ptr<RawTensor> rawTensor, Offset toffset, Shape tshape);
    LogicalTensor(
        Function& function, std::shared_ptr<RawTensor> rawTensor, Offset toffset, Shape tshape,
        std::vector<SymbolicScalar> tValidShape);

    LogicalTensor(LogicalTensor&&) = delete;
    LogicalTensor(const LogicalTensor&) = delete;
    LogicalTensor& operator=(LogicalTensor&&) = delete;
    LogicalTensor& operator=(const LogicalTensor&) = delete;

    std::shared_ptr<LogicalTensor> Clone(Function& dstFunc, bool create = false) const;

    /// Function this tensor was associated with; dereferences the stored function_ pointer.
    Function& BelongFunction() { return *function_; }
    const Function& BelongFunction() const { return *function_; }

    // Dump / serialization helpers (defined out of line).
    std::string DumpType() const;
    Json DumpJson(Function& func, bool dumpRawTensor = true) const;
    static std::shared_ptr<LogicalTensor> LoadJson(
        Function& function, const std::unordered_map<int, std::shared_ptr<RawTensor>>& rawTensorDict,
        const Json& tensorDump);
    std::string DumpSSA(bool showFrom = true, bool showMem = false, bool showType = true) const;
    std::string Dump(bool showFrom = true, bool showMem = false) const;

    std::shared_ptr<LogicalTensor> View(Function& function, const Shape& newShape, const Offset& newOffset) const;
    DataType Datatype() const;
    std::string Symbol() const;

    /// Format of the backing raw tensor. Requires `tensor` to be non-null.
    TileOpFormat Format() const { return tensor->format; }

    // Memory-type accessors (defined out of line); two values are tracked: "original" and "to be".
    MemoryType GetMemoryTypeOriginal() const;
    MemoryType GetMemoryTypeToBe() const;
    void CopyMemoryType(const std::shared_ptr<LogicalTensor>& other);
    void SetMemoryTypeBoth(MemoryType t, bool force = false);
    void SetMemoryTypeOriginal(MemoryType t, bool force = false);
    void SetMemoryTypeToBe(MemoryType t);
    bool MemoryConflict() const;
    size_t MemorySize() const;

    bool IsDummy() const;
    void SetIsDummy(bool dummy = true);
    bool Overlap(const std::shared_ptr<LogicalTensor>& other) const;

    int GetMagic() const { return magic; }
    void SetMagic(int m) { magic = m; }
    /// Magic of the backing raw tensor. Requires `tensor` to be non-null.
    int GetRawMagic() const { return tensor->GetRawMagic(); }
    std::shared_ptr<RawTensor> GetRawTensor() const { return tensor; }
    const Offset& GetOffset() const { return offset; }
    const Shape& GetShape() const { return shape; }

    /// Replaces `offset`; asserts that the new offset has as many entries as `shape`.
    void UpdateOffset(const Offset& newOffset)
    {
        FE_ASSERT(FeError::INVALID_VAL, newOffset.size() == shape.size())
            << "newOffset.size(): " << newOffset.size() << ", shape.size(): " << shape.size();
        offset = newOffset;
    }

    /// Replaces both `offset` and `dynOffset_` from a TensorOffset. Unlike the overload above, no
    /// size check against `shape` is performed here.
    void UpdateOffset(const TensorOffset& tensorOffset)
    {
        this->offset = tensorOffset.GetOffset();
        this->dynOffset_ = tensorOffset.GetDynOffset();
    }

    /// Builds a TensorOffset from the current `offset` and `dynOffset_`.
    const TensorOffset GetTensorOffset() const { return TensorOffset(offset, dynOffset_); }
    void UpdateDynValidShape(const std::vector<SymbolicScalar>& dynValidShape) { dynValidShape_ = dynValidShape; }

    /// Ordering used by the producer/consumer/depend-op sets; the comparison is defined out of line.
    struct CompareOp {
        bool operator()(const Operation* a, const Operation* b) const;
    };

    // Direct access to the operation sets.
    auto& GetProducers() { return producers_; }
    auto& GetConsumers() { return consumers_; }
    auto& GetDependOps() { return dependOps_; }
    const auto& GetProducers() const { return producers_; }
    const auto& GetConsumers() const { return consumers_; }
    const auto& GetDependOps() const { return dependOps_; }

    // Membership tests; the reference overloads forward to the pointer overloads.
    bool HasProducer(Operation* operation) const { return producers_.count(operation) > 0; }
    bool HasConsumer(Operation* operation) const { return consumers_.count(operation) > 0; }
    bool HasDependOp(Operation* operation) const { return dependOps_.count(operation) > 0; }
    bool HasProducer(Operation& operation) const { return HasProducer(&operation); }
    bool HasConsumer(Operation& operation) const { return HasConsumer(&operation); }
    bool HasDependOp(Operation& operation) const { return HasDependOp(&operation); }

    // Registration / removal; the sets store raw, non-owning Operation pointers.
    void AddProducer(Operation* operation) { producers_.emplace(operation); }
    void AddConsumer(Operation* operation) { consumers_.emplace(operation); }
    void AddDependOp(Operation* operation) { dependOps_.emplace(operation); }
    void RemoveProducer(Operation* operation) { producers_.erase(operation); }
    void RemoveConsumer(Operation* operation) { consumers_.erase(operation); }
    void RemoveDependOp(Operation* operation) { dependOps_.erase(operation); }
    void AddProducer(Operation& operation) { AddProducer(&operation); }
    void AddConsumer(Operation& operation) { AddConsumer(&operation); }
    void AddDependOp(Operation& operation) { AddDependOp(&operation); }
    void RemoveProducer(Operation& operation) { RemoveProducer(&operation); }
    void RemoveConsumer(Operation& operation) { RemoveConsumer(&operation); }
    void RemoveDependOp(Operation& operation) { RemoveDependOp(&operation); }
    void ClearAllProducers() { producers_.clear(); }

    void operator<<(LogicalTensor& right);
    int64_t GetDataSize() const;

    /// True when every component of `offset` equals zero (vacuously true for an empty offset).
    bool IsOffsetAllZero() const
    {
        // Compare each element in its own type: a fixed `int` lambda parameter would truncate wider
        // offset components before the comparison and could report a non-zero offset as zero.
        return std::all_of(offset.begin(), offset.end(), [](const auto& value) { return value == 0; });
    }

    const std::vector<SymbolicScalar>& GetDynOffset() const { return dynOffset_; }
    const std::vector<SymbolicScalar>& GetDynValidShape() const { return dynValidShape_; }
    std::vector<SymbolicScalar>& GetDynValidShape() { return dynValidShape_; }

    /// Forwards to the raw tensor; does nothing when `tensor` is null.
    void SetCachePolicy(CachePolicy policy, bool value)
    {
        if (tensor != nullptr) {
            tensor->SetCachePolicy(policy, value);
        }
    }

    /// Forwards to the raw tensor; returns false when `tensor` is null.
    bool GetCachePolicy(CachePolicy policy) const
    {
        if (tensor != nullptr) {
            return tensor->GetCachePolicy(policy);
        }
        return false;
    }

    bool IsGetTensorDataOutcast();

private:
    MemoryType memoryTypeOriginal_{MemoryType::MEM_UNKNOWN};
    MemoryType memoryTypeToBe_{MemoryType::MEM_UNKNOWN};
    int readyTime_{INVALID_TIME};
    int remainingTime_{INVALID_TIME};
    // Non-owning; dereferenced by BelongFunction(). NOTE(review): presumably set by the constructors.
    Function* function_;
    std::set<Operation*, CompareOp> producers_;
    std::set<Operation*, CompareOp> consumers_;
    std::set<Operation*, CompareOp> dependOps_;
};
// Opcode tags accepted by CheckEmuOpcode() and SetEmuOpcode().
// The enumerators are consecutive starting at zero; the values are spelled out explicitly.
enum EmuOpcode {
    EMUOP_TENSOR_EXTRACT = 0,
    EMUOP_TENSOR_INSERT = 1,
    EMUOP_TENSOR_GETDATA_DEPEND = 2,
    EMUOP_TENSOR_GETDATA_IMPORT = 3,
    EMUOP_TENSOR_SETDATA = 4,
};
// Emulation-opcode helpers, defined elsewhere.
// NOTE(review): presumably CheckEmuOpcode reports whether `op` carries `opcode` and SetEmuOpcode
// attaches `opcode` to `op` -- confirm against the definitions.
bool CheckEmuOpcode(const Operation* op, EmuOpcode opcode);
void SetEmuOpcode(Operation* op, EmuOpcode opcode);
// NOTE(review): both tensor handles are taken by value; which properties count as the "type" is
// decided by the out-of-line definition -- confirm there.
bool TypeEqual(const LogicalTensorPtr a, const LogicalTensorPtr b);
// Extract returns a Tensor built from `src` and a symbolic `offset`; Insert takes `src`, a symbolic
// `offset` and a mutable `dst`. NOTE(review): exact semantics live in the definitions -- confirm.
Tensor TensorExtract(const Tensor& src, const std::vector<SymbolicScalar>& offset);
void TensorInsert(const Tensor& src, const std::vector<SymbolicScalar>& offset, Tensor& dst);
// Valid-shape helpers for views: a per-dimension form and a whole-shape form.
// NOTE(review): the formula that combines valid shape, view offset and view shape is not visible here.
SymbolicScalar GetViewValidShapeDim(
const SymbolicScalar& validShapeDim, const SymbolicScalar& viewOffsetDim, const SymbolicScalar& viewShapeDim);
std::vector<SymbolicScalar> GetViewValidShape(
const std::vector<SymbolicScalar>& validShape, const Offset& viewOffset,
const std::vector<SymbolicScalar>& viewDynOffset, const Shape& viewShape);
// Three overloads that return an int-keyed map of RawSymbolicScalarPtr lists, taking a single scalar,
// a vector of scalars, or a vector of scalar references.
// NOTE(review): the meaning of the int key is not visible in this header -- confirm.
std::map<int, std::vector<RawSymbolicScalarPtr>> GetTensorDataDict(const SymbolicScalar& dimOffset);
std::map<int, std::vector<RawSymbolicScalarPtr>> GetTensorDataDict(const std::vector<SymbolicScalar>& offset);
std::map<int, std::vector<RawSymbolicScalarPtr>> GetTensorDataDict(
const std::vector<std::reference_wrapper<SymbolicScalar>>& offset);
/**
 * Plain record holding an IO type, an index within that IO type, and a symbolic address.
 *
 * A default-constructed descriptor has ioType == -1 and ioTypeIndex == -1.
 * NOTE(review): ioType presumably takes one of the GET_TENSOR_DATA_OPERAND_IOTYPE_* values -- confirm.
 */
struct GetTensorDataIODesc {
    int ioType{-1};
    int ioTypeIndex{-1};
    SymbolicScalar address;

    GetTensorDataIODesc() = default;

    /**
     * @param ioType_      value stored in ioType
     * @param ioTypeIndex_ value stored in ioTypeIndex
     * @param address_     value stored in address; taken by value and moved into the member so an
     *                     rvalue argument is not copied a second time
     */
    GetTensorDataIODesc(int ioType_, int ioTypeIndex_, SymbolicScalar address_)
        : ioType(ioType_), ioTypeIndex(ioTypeIndex_), address(std::move(address_))
    {}
};
// Getter/setter pairs for two integer indices associated with an Operation ("index" and "coa index").
// The "GetTensorData" part of each name is a feature prefix; the following Get/Set is the action.
// NOTE(review): where the index is stored on the operation is decided by the definitions -- confirm.
int GetTensorDataGetIndex(const Operation* op);
void GetTensorDataSetIndex(Operation* op, int index);
int GetTensorDataGetCoaIndex(const Operation* op);
void GetTensorDataSetCoaIndex(Operation* op, int index);
// An int-keyed unordered_map of GetTensorDataIODesc entries, extended with a Dump() string helper
// (defined out of line).
// NOTE(review): this publicly inherits from std::unordered_map, which has no virtual destructor, so
// instances should not be deleted through a pointer to the base map type.
struct GetTensorDataIODescDict : std::unordered_map<int, GetTensorDataIODesc> {
std::string Dump() const;
};
// Operand index constants for get-tensor-data.
// NOTE(review): the names suggest operand positions (callee, index, io type, io-type index, address);
// confirm against the code that builds and parses these operands.
constexpr int GET_TENSOR_DATA_OPERAND_INDEX_CALLEE{0};
constexpr int GET_TENSOR_DATA_OPERAND_INDEX_INDEX{1};
constexpr int GET_TENSOR_DATA_OPERAND_INDEX_IOTYPE{2};
constexpr int GET_TENSOR_DATA_OPERAND_INDEX_IOTYPE_INDEX{3};
constexpr int GET_TENSOR_DATA_OPERAND_INDEX_ADDRESS{4};

// IO type values for get-tensor-data: incast is 0, outcast is 1.
constexpr int GET_TENSOR_DATA_OPERAND_IOTYPE_INCAST{0};
constexpr int GET_TENSOR_DATA_OPERAND_IOTYPE_OUTCAST{1};
// Helpers operating on symbolic scalars related to get-tensor-data; all are defined elsewhere.
// NOTE(review): presumably FillIO substitutes IO descriptors from `iodescDict` into `dimOffset`,
// UpdateGetTensorDataIOIndex rewrites references from `currOutcastIdx` to `newOutcastIdx`, and
// GetTensorDataUsage collects the pairs of ints referenced by `scalars` -- confirm against the
// definitions; none of this is visible in this header.
SymbolicScalar GetTensorDataFillIO(const GetTensorDataIODescDict& iodescDict, const SymbolicScalar& dimOffset);
SymbolicScalar UpdateGetTensorDataIOIndex(size_t currOutcastIdx, size_t newOutcastIdx, const SymbolicScalar& scalar);
std::set<std::pair<int, int>> GetTensorDataUsage(const std::vector<std::reference_wrapper<SymbolicScalar>>& scalars);
// Operand index constants for the runtime get-param-offset call.
// NOTE(review): the names suggest operand positions (dim size, coa index, dim index); index 0 has no
// constant here -- confirm against the code that uses these operands.
constexpr int RUNTIME_GET_PARAM_OFFSET_OPERAND_INDEX_DIM_SIZE_INDEX{1};
constexpr int RUNTIME_GET_PARAM_OFFSET_OPERAND_INDEX_COA_INDEX{2};
constexpr int RUNTIME_GET_PARAM_OFFSET_OPERAND_INDEX_DIM_INDEX{3};
}