/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
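/*!
 * \file tensor_slot.h
 * \brief Tensor slot bookkeeping: slot handles, per-scope slot access records, and the
 *        incast/outcast slot links assembled by TensorSlotManager.
 */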
#pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tilefwk/tensor.h"
#include "interface/utils/common.h"
#include "interface/tensor/logical_tensor.h"
#include "interface/tensor/runtime_slot.h"
namespace npu::tile_fwk {
/**
 * \brief Extra property that may accompany a tensor write
 *        (the optional argument of TensorSlotManager::TensorWrite, default NONE).
 *
 * Enumerator values are spelled out; they match the implicit 0..3 numbering.
 * The per-enumerator notes below are inferred from the names only — confirm with the .cpp.
 */
enum class SlotProperty : uint32_t {
    NONE = 0,          // no extra property
    OUTPUT = 1,        // presumably: the written tensor is a function output
    ASSEMBLE_DST = 2,  // presumably: the written tensor is an assemble destination
    SHMEM_TENSOR = 3,  // presumably: the written tensor is a shared-memory tensor
};
/**
 * \brief Lightweight handle identifying a tensor slot by its integer id.
 *
 * Equality (and the std::hash specialization for this type) depends on the id only;
 * the stored pointer is an opaque, non-owning reference and is never compared.
 * A default-constructed handle has id -1 and a null slot pointer.
 */
struct TensorSlot {
public:
    TensorSlot() = default;
    TensorSlot(int64_t id, const void* slot) : id_(id), slot_(slot) {}

    /// Opaque, non-owning pointer supplied at construction (may be null).
    const void* GetSlot() const noexcept { return slot_; }
    /// Slot id; the sole key for equality and hashing.
    int64_t GetId() const noexcept { return id_; }

    std::string GetSymbolName() const;
    std::shared_ptr<LogicalTensor> GetSlotValue() const;
    void SetSlotValue(const std::shared_ptr<LogicalTensor>& value) const;
    std::string Dump() const;
    std::string DumpHead(const std::string& name) const;

    bool operator==(const TensorSlot& oth) const noexcept { return id_ == oth.id_; }
    // C++17 does not synthesize != from ==; provide it so both comparisons compile.
    bool operator!=(const TensorSlot& oth) const noexcept { return !(*this == oth); }

    /**
     * \brief Make a slot handle for \p tensor: id = tensor.Id(), slot pointer = &tensor.
     *
     * The pointer is non-owning: it is only meaningful while \p tensor is alive.
     */
    static TensorSlot CreateTensor(const Tensor& tensor) { return TensorSlot(tensor.Id(), &tensor); }

private:
    int64_t id_{-1};
    const void* slot_{nullptr};
};
}
namespace std {
/**
 * \brief Hash a TensorSlot by its id, consistent with TensorSlot::operator==
 *        (which also compares ids only).
 */
template <>
struct hash<npu::tile_fwk::TensorSlot> {
    std::size_t operator()(const npu::tile_fwk::TensorSlot& slot) const
    {
        const int64_t slotId = slot.GetId();
        return hash<int64_t>{}(slotId);
    }
};
}  // namespace std
namespace npu::tile_fwk {
class Function;
/**
 * \brief Read/write record for a single tensor slot within one scope.
 *
 * Keeps a "written" flag, the tensor stored by Read() calls made before the first
 * Write(), and the tensor passed to the most recent Write().
 */
struct TensorSlotAccess {
public:
    TensorSlotAccess() {}

    const std::shared_ptr<LogicalTensor>& GetFirstReadTensor() const { return firstReadTensor_; }
    std::shared_ptr<LogicalTensor> GetFirstReadTensor() { return firstReadTensor_; }
    const std::shared_ptr<LogicalTensor>& GetLastWriteTensor() const { return lastWriteTensor_; }
    std::shared_ptr<LogicalTensor> GetLastWriteTensor() { return lastWriteTensor_; }

    /**
     * \brief Record a read of \p tensor; reads after any Write() are ignored.
     *
     * NOTE(review): every read before the first write overwrites firstReadTensor_, so it
     * ends up holding the latest pre-write read rather than the first — confirm intended.
     */
    void Read(const std::shared_ptr<LogicalTensor>& tensor)
    {
        if (written_) {
            return;
        }
        firstReadTensor_ = tensor;
    }

    /// Record a write of \p tensor; marks the slot as written for the rest of the scope.
    void Write(const std::shared_ptr<LogicalTensor>& tensor)
    {
        lastWriteTensor_ = tensor;
        written_ = true;
    }

    /// Render as "<W|' ',firstRead-or-noread,lastWrite-or-nowrite>".
    std::string Dump() const
    {
        const char writtenMark = written_ ? 'W' : ' ';
        const std::string readPart = firstReadTensor_ ? firstReadTensor_->Dump() : std::string("noread");
        const std::string writePart = lastWriteTensor_ ? lastWriteTensor_->Dump() : std::string("nowrite");

        std::ostringstream oss;
        oss << "<" << writtenMark << "," << readPart << "," << writePart << ">";
        return oss.str();
    }

private:
    bool written_{false};
    std::shared_ptr<LogicalTensor> firstReadTensor_;
    std::shared_ptr<LogicalTensor> lastWriteTensor_;
};
/**
 * \brief Per-function mapping from incasts/outcasts to the slot indices they go through.
 *
 * One incast or outcast may be associated with several slots, which is why each entry
 * is a list of slot indices:
 *
 *   Tensor a("a"), b("b");
 *   Tensor v0, v1;
 *   FUNCTION("A") {
 *       Tensor t = a + b;
 *       v0 = t;  // t outcast via v0
 *       v1 = t;  // t outcast via v1
 *   }
 *
 *   Tensor x0;
 *   FUNCTION("B") {
 *       // t incast via both v0 and v1
 *       x0 = v0 + v1;
 *   }
 *
 * Here A's outcast t maps to the slots of {v0, v1}, and B reads both slots as incasts.
 */
struct IncastOutcastSlot {
    // incastSlot[i]: slot indices feeding incast i (presumably index-aligned with the
    // function's incast list — confirm).
    std::vector<std::vector<int>> incastSlot;
    // outcastSlot[i]: slot indices written by outcast i (same alignment assumption).
    std::vector<std::vector<int>> outcastSlot;
    // Presumably the indices of outcasts that are only partially updated — confirm.
    std::vector<int> partialUpdateOutcastList;
};
/**
 * \brief Slot bookkeeping collected for one traced function (tensorFunc).
 *
 * Move-only: the move operations are defaulted and copying is disabled (copies were
 * already implicitly deleted by the user-declared move operations; this makes it explicit).
 * Field descriptions marked "presumably" are inferred from names — confirm with the .cpp.
 */
struct TensorSlotScope {
    // Function this scope was created for (constructor argument).
    Function* tensorFunc = nullptr;
    // Read/write record of every slot touched in this scope.
    std::unordered_map<TensorSlot, TensorSlotAccess> accessRecord;
    // Presumably incast tensor -> corresponding in-argument tensor.
    std::unordered_map<std::shared_ptr<LogicalTensor>, std::shared_ptr<LogicalTensor>> incastToInArgumentDict;
    // Presumably outcast tensor -> corresponding out-argument tensor.
    std::unordered_map<std::shared_ptr<LogicalTensor>, std::shared_ptr<LogicalTensor>> outcastToOutArgumentDict;
    // Presumably one slot set per incast / outcast (filled by BuildSlotSet) — confirm.
    std::vector<std::unordered_set<TensorSlot>> incastReadSlotSet;
    std::vector<std::unordered_set<TensorSlot>> outcastWriteSlotSet;
    // Presumably the original (pre-adjustment) versions of the two sets above.
    std::vector<std::unordered_set<TensorSlot>> oriIncastReadSlotSet;
    std::vector<std::unordered_set<TensorSlot>> oriOutcastWriteSlotSet;
    std::unordered_map<LogicalTensorPtr, int> partialUpdateOutcastDict;
    std::vector<int> constructAssembleSlotList;
    // Slot-index form of the incast/outcast sets (built by BuildIncastOutcastSlot).
    IncastOutcastSlot ioslot;
    IncastOutcastSlot originalIocastsSlot;

    explicit TensorSlotScope(Function* tfunc) : tensorFunc(tfunc) {}
    TensorSlotScope(const TensorSlotScope&) = delete;
    TensorSlotScope& operator=(const TensorSlotScope&) = delete;
    TensorSlotScope(TensorSlotScope&& scope) = default;
    TensorSlotScope& operator=(TensorSlotScope&& scope) = default;

    std::unordered_set<TensorSlot> LookupIncastReadFrom(const std::shared_ptr<LogicalTensor>& tensor) const;
    std::unordered_set<TensorSlot> LookupOutcastWriteTo(const std::shared_ptr<LogicalTensor>& tensor) const;
    // Original (misspelled) name, kept so existing callers and the out-of-line definition still work.
    std::unordered_set<TensorSlot> LoopupArgSlot(std::shared_ptr<RawTensor> tensor);
    /// Correctly spelled alias of LoopupArgSlot; prefer this name in new code.
    std::unordered_set<TensorSlot> LookupArgSlot(std::shared_ptr<RawTensor> tensor)
    {
        return LoopupArgSlot(std::move(tensor));
    }
    void BuildSlotSet();
    void BuildIncastOutcastSlot(const std::unordered_map<TensorSlot, int>& slotIndexDict);
    std::string Dump() const;
};
struct IncastOutcastLink {
explicit IncastOutcastLink(int slotNum = 0) : totalSlot(slotNum) {}
int totalSlot;
std::unordered_map<Function*, IncastOutcastSlot> ioslotDict;
std::vector<int> inputSlotIndexList;
std::vector<int> outputSlotIndexList;
std::vector<int> assembleSlotIndexList;
std::vector<int> shmemTensorSlotIndexList;
std::vector<int> inplaceSlotIndexList;
std::vector<int> partialUpdateSlotIdexList;
std::vector<RuntimeSlotKindSet> runtimeSlotKindSetList;
void UpdateRuntimeSlotKindSetList()
{
runtimeSlotKindSetList.resize(totalSlot);
for (int inputSlotIndex : inputSlotIndexList) {
runtimeSlotKindSetList[inputSlotIndex].Add(RuntimeSlotKind::INPUT);
}
for (int outputSlotIndex : outputSlotIndexList) {
runtimeSlotKindSetList[outputSlotIndex].Add(RuntimeSlotKind::OUTPUT);
}
for (int assembleSlotIndex : assembleSlotIndexList) {
runtimeSlotKindSetList[assembleSlotIndex].Add(RuntimeSlotKind::ASSEMBLE_OUTCAST);
}
for (int shmemSlotIndex : shmemTensorSlotIndexList) {
runtimeSlotKindSetList[shmemSlotIndex].Add(RuntimeSlotKind::ADDRESS_EXPRESSION);
}
}
};
/**
 * \brief Value stored per slot in TensorSlotCheckpoint::slotDict: the slot's tensor and a
 *        count (presumably the number of live references to the slot — confirm).
 */
struct SlotInfo {
    std::shared_ptr<LogicalTensor> tensor;
    // Default member initializer: `SlotInfo info;` used to leave this indeterminate.
    // Aggregate initialization such as SlotInfo{t, n} is unaffected (C++14+).
    int refCount{0};
};
struct TensorSlotCheckpoint {
std::unordered_map<TensorSlot, SlotInfo> slotDict;
std::unordered_map<std::shared_ptr<LogicalTensor>, std::set<Operation*, LogicalTensor::CompareOp>> producerDict;
std::unordered_map<std::shared_ptr<LogicalTensor>, std::set<Operation*, LogicalTensor::CompareOp>> consumerDict;
};
/**
 * \brief Functions associated with a slot's lifetime events
 *        (TensorSlotManager::slotUsageDict maps each slot to one of these).
 *
 * Every pointer is non-owning and null until set. The per-field meaning below is
 * inferred from the names — confirm with the .cpp.
 */
struct TensorSlotUsage {
    // Presumably the functions in which the slot is constructed / destructed.
    Function* construct = nullptr;
    Function* destruct = nullptr;
    // Presumably the first and last functions that read the slot.
    Function* readFirst = nullptr;
    Function* readLast = nullptr;
    // Presumably the first and last functions that write the slot.
    Function* writeFirst = nullptr;
    Function* writeLast = nullptr;
};
/**
 * \brief Central bookkeeper for tensor slots across traced function scopes.
 *
 * Holds the list of TensorSlotScope objects (BeginScope/EndScope, currScope), per-slot
 * index / usage / naming data, the marked inputs and outputs, in-place pairs, and a stack
 * of TensorSlotCheckpoint for Checkpoint()/Restore(). BuildIncastOutcastLink() produces
 * the cross-function IncastOutcastLink.
 * NOTE(review): no synchronization is visible here; presumably single-threaded — confirm.
 * Data-member order is left untouched (layout / initialization order).
 */
struct TensorSlotManager {
// Scopes created so far, and the one currently being recorded.
std::vector<std::shared_ptr<TensorSlotScope>> scopeList;
std::shared_ptr<TensorSlotScope> currScope;
// Slot -> dense index (presumably the index space used by IncastOutcastLink — confirm).
std::unordered_map<TensorSlot, int> slotIndexDict;
std::unordered_map<TensorSlot, TensorSlotUsage> slotUsageDict;
std::unordered_set<TensorSlot> liveSlotSet;
std::unordered_set<TensorSlot> assembleSlotSet;
std::unordered_set<TensorSlot> shmemTensorSlotSet;
// Name lookups in both directions (symbol name -> slot, slot -> name).
std::unordered_map<std::string, TensorSlot> symbolNameDict;
std::unordered_map<TensorSlot, std::string> slotNameDict;
std::unordered_map<TensorSlot, std::string> slotFuncNameDict;
// Inputs: ordered list, slot -> position, and names (presumably filled by MarkInput).
std::vector<TensorSlot> inputSlotList;
std::unordered_map<TensorSlot, int> inputSlotDict;
std::vector<std::string> inputNameList;
// Outputs: ordered list, slot -> position, and names (presumably filled by MarkOutput).
std::vector<TensorSlot> outputSlotList;
std::unordered_map<TensorSlot, int> outputSlotDict;
std::vector<std::string> outputNameList;
// Presumably out-slot -> in-slot pairs recorded by MarkInplace.
std::unordered_map<TensorSlot, TensorSlot> inplaceDict;
std::unordered_map<std::string, int> nameDict;
std::unordered_map<TensorSlot, TensorSlot> reshapeInplaceDict;
std::set<int> partialUpdateSlotIndexSet;
// Saved states for Checkpoint()/Restore().
std::vector<TensorSlotCheckpoint> checkpointStack;
std::unordered_set<TensorSlot> recycleSlotSet;
void SetRecording(bool isRecording);
// Scope lifetime.
void BeginScope(Function* tensorFunc);
std::shared_ptr<TensorSlotScope> EndScope();
void ConnectSlot(std::shared_ptr<TensorSlotScope> scope);
// Slot-level events.
void TensorSlotRead(const TensorSlot& slot, const std::shared_ptr<LogicalTensor>& tensor);
void TensorSlotWrite(const TensorSlot& slot, const std::shared_ptr<LogicalTensor>& tensor);
void TensorSlotConstruct(const TensorSlot& slot);
void TensorSlotDestruct(const TensorSlot& slot);
// Tensor-level events (property defaults to SlotProperty::NONE).
void TensorRead(const Tensor& tensor);
void TensorWrite(const Tensor& tensor, SlotProperty property = SlotProperty::NONE);
void TensorConstruct(const Tensor& tensor);
void TensorDestruct(const Tensor& tensor);
void TensorSymbol(const Tensor& tensor, const std::string& symbolName);
TensorSlotUsage& GetTensorSlotUsage(const TensorSlot& slot);
// Slot-index lookups for lists of tensors or symbol names.
std::vector<int> LookupSlotIndex(const std::vector<std::reference_wrapper<Tensor>>& tensorList);
std::vector<int> LookupSlotIndexConst(const std::vector<std::reference_wrapper<const Tensor>>& tensorList);
std::vector<int> LookupSlotIndexBySymbol(const std::vector<std::string>& symbolNameList);
// Input/output/in-place marking.
void MarkInput(const Tensor& tensor);
void MarkOutput(const Tensor& tensor);
void MarkInplace(const Tensor& out, const Tensor& in);
const std::vector<std::string>& GetInputNameList() const { return inputNameList; }
const std::vector<std::string>& GetOutputNameList() const { return outputNameList; }
int GetInputIndex(const Tensor& tensor);
int GetOutputIndex(const Tensor& tensor);
int GetSlotIndex(const Tensor& tensor);
// Save / restore slot state via checkpointStack.
void Checkpoint();
void Restore();
void UpdateReshapeInplaceSlots(IncastOutcastLink& link);
void SetSameSlot(const Tensor& operand, const Tensor& dst);
IncastOutcastLink BuildIncastOutcastLink(const std::string& rawname = "");
static bool HasSameSlot(const std::vector<int>& slots1, const std::vector<int>& slots2);
std::string Dump() const;
private:
void LogOperation(const TensorSlot& slot, const std::string& op);
void InsertLiveSlot(const TensorSlot& slot);
void TensorSlotRecycle(const TensorSlot& slot);
// Set by SetRecording(); starts out false.
bool isRecording_{false};
};
}