#include "NPUSwapManager.h"
#include <sstream>
#include <ATen/record_function.h>
#include <torch_npu/csrc/core/NPUStorageImpl.h>
#include <torch_npu/csrc/core/npu/NPUCachingAllocator.h>
#include <torch_npu/csrc/core/npu/NPUFunctions.h>
#include <torch_npu/csrc/core/npu/CachingHostAllocator.h>
#include <torch_npu/csrc/core/npu/NPUMemorySwap.h>
#include <torch_npu/csrc/framework/OpHook.h>
#include "swap_log.h"
#include "SwapException.h"
namespace c10_npu {
namespace swap {
SwapStage::SwapStage() : stageType(SwapStageType::INIT), microBatchIndex(0), layerIndex(0), modelIndex(0) {}
bool SwapStage::operator == (const SwapStage &other) const
{
return (stageType == other.stageType && microBatchIndex == other.microBatchIndex && layerIndex == other.layerIndex
&& modelIndex == other.modelIndex);
}
std::ostream &operator << (std::ostream &os, const SwapStage &obj)
{
os << "SwapStage: "
<< "stageType: " << static_cast<int>(obj.stageType) << " "
<< "microBatchIndex: " << obj.microBatchIndex << " "
<< "layerIndex: " << obj.layerIndex << " "
<< "modelIndex: " << obj.modelIndex << std::endl;
return os;
}
SwapConfig::SwapConfig()
: microBatchNum(0),
layerNum(0),
isOOM(false),
step(0),
oneStepDuration(0.0),
policyStep(0),
enableProfiler(false),
tensorSizeThresh(0),
enableExecutor(false),
enableCustomRecordStream(true)
{}
UniqueSwapPtr::UniqueSwapPtr() : ptrBase(0), index(0) {}
bool UniqueSwapPtr::operator == (const UniqueSwapPtr &other) const
{
return ptrBase == other.ptrBase && index == other.index;
}
std::ostream &operator << (std::ostream &os, const UniqueSwapPtr &obj)
{
os << "UniqueSwapPtr: "
<< "ptrBase: " << std::hex << obj.ptrBase << std::dec << " "
<< "index: " << obj.index << std::endl;
return os;
}
UniqueSwapPtr::operator std::string() const
{
std::stringstream ss;
ss << ptrBase << "_" << index;
return ss.str();
}
UniqueSwapMemory::UniqueSwapMemory() : allocated_bytes(0), reserved_bytes(0), active_bytes(0) {}
UniqueSwapMemory::UniqueSwapMemory(int64_t allocated_bytes, int64_t reserved_bytes, int64_t active_bytes)
: allocated_bytes(allocated_bytes), reserved_bytes(reserved_bytes), active_bytes(active_bytes)
{}
std::ostream &operator << (std::ostream &os, const UniqueSwapMemory &obj)
{
os << "UniqueSwapMemory: "
<< "allocated_bytes: " << obj.allocated_bytes << " "
<< "reserved_bytes: " << obj.reserved_bytes << " "
<< "active_bytes: " << obj.active_bytes << std::endl;
return os;
}
ProfilerTensorInfo::ProfilerTensorInfo(const at::Tensor &tensor)
{
this->ptr = NPUSwapManager::GetInstance().getUniqueSwapPtr(tensor);
this->nbytes = tensor.storage().nbytes();
this->dtype = tensor.scalar_type();
auto tensorPtrTypeIter = NPUSwapManager::GetInstance().tensorPtrTypeMap.find(this->ptr);
if (tensorPtrTypeIter == NPUSwapManager::GetInstance().tensorPtrTypeMap.end()) {
this->tensorType = SwapTensorType::OTHERS;
} else {
this->tensorType = tensorPtrTypeIter->second;
}
for (int i = 0; i < tensor.sizes().size(); i++) {
this->shapeV2.push_back(tensor.sizes()[i]);
}
}
std::ostream &operator << (std::ostream &os, const ProfilerTensorInfo &obj)
{
os << "ProfilerTensorInfo: "
<< "ptr: " << obj.ptr << " "
<< "nbytes: " << obj.nbytes << " "
<< "dtype: " << obj.dtype << " "
<< "tensorType: " << static_cast<int>(obj.tensorType) << " "
<< "shape: " << obj.shapeV2 << std::endl;
return os;
}
ProfilerOpInfo::ProfilerOpInfo(int opId, std::string opName, int64_t allocated_bytes, int64_t reserved_bytes,
int64_t active_bytes)
: opId(opId), opName(opName), swapMemory(allocated_bytes, reserved_bytes, active_bytes)
{
this->stage = NPUSwapManager::GetInstance().config.stage;
this->step = NPUSwapManager::GetInstance().config.step;
}
std::ostream &operator << (std::ostream &os, const ProfilerOpInfo &obj)
{
os << "ProfilerOpInfo: "
<< "opId: " << obj.opId << " "
<< "opName: " << obj.opName << " "
<< "stage: " << obj.stage << " "
<< "step: " << obj.step << " "
<< "swapMemory: " << obj.swapMemory << std::endl;
for (auto &t : obj.profilerTensorInfoVec) {
os << t << std::endl;
}
return os;
}
void ProfilerOpInfo::appendTensorInfo(const at::Tensor &tensor)
{
profilerTensorInfoVec.emplace_back(ProfilerTensorInfo(tensor));
}
ProfilerSwapInfo::ProfilerSwapInfo(int opId, std::string swapName, size_t size, bool isOOM, UniqueSwapPtr srcDataPtr,
UniqueSwapPtr dstDataPtr)
: opId(opId), swapName(swapName), size(size), isOOM(isOOM), srcPtr(srcDataPtr), dstPtr(dstDataPtr)
{}
SwapProfiler::SwapProfiler() : isInit(false) {}
SwapProfiler::~SwapProfiler()
{
isInit = false;
}
int SwapProfiler::Init()
{
isInit = true;
lastOpId = 0;
return 0;
}
void SwapProfiler::updateStep()
{
profilerOpInfoMap[NPUSwapManager::GetInstance().config.step] = profilerOpInfoVec;
lastOpId = profilerOpInfoVec.back().opId;
profilerOpInfoVec.clear();
profilerSwapInfoVec.clear();
}
void SwapProfiler::appendOpInfo(std::string &opName, int &opId)
{
int device = 0;
SWAP_CHECK_ERROR(c10_npu::GetDevice(&device));
const c10_npu::NPUCachingAllocator::DeviceStats stats =
c10_npu::NPUCachingAllocator::allocator.load()->getDeviceStats(device);
ProfilerOpInfo profilerOpInfo(opId, opName,
stats.allocated_bytes[static_cast<size_t>(c10_npu::NPUCachingAllocator::StatType::AGGREGATE)].current,
stats.reserved_bytes[static_cast<size_t>(c10_npu::NPUCachingAllocator::StatType::AGGREGATE)].current,
stats.active_bytes[static_cast<size_t>(c10_npu::NPUCachingAllocator::StatType::AGGREGATE)].current);
profilerOpInfoVec.emplace_back(profilerOpInfo);
}
void SwapProfiler::ReportInfoV2(std::string &opName, int &opId, c10::SmallVector<at::Tensor, N> &tensors)
{
appendOpInfo(opName, opId);
ProfilerOpInfo &profilerOpInfo = profilerOpInfoVec.back();
for (const auto &tensor : tensors) {
profilerOpInfo.appendTensorInfo(tensor);
}
}
void SwapProfiler::ReportInfoV2(bool isSwapOut, at::DataPtr &srcDataPtr, at::DataPtr &dstDataPtr, size_t size,
bool isOOM)
{
std::string swapName = isSwapOut ? "swapOut" : "swapIn";
int opId = profilerOpInfoVec.empty() ? lastOpId : profilerOpInfoVec.back().opId;
ProfilerSwapInfo profilerSwapInfo(opId, swapName, size, isOOM,
NPUSwapManager::GetInstance().getUniqueSwapPtr(srcDataPtr.get()),
NPUSwapManager::GetInstance().getUniqueSwapPtr(dstDataPtr.get()));
profilerSwapInfoVec.emplace_back(profilerSwapInfo);
}
std::vector<ProfilerOpInfo> &SwapProfiler::getPolicyStepOpVec()
{
return profilerOpInfoMap[NPUSwapManager::GetInstance().config.policyStep];
}
SwapPolicyInfo::SwapPolicyInfo() : executorNeedMatch(false), swapOutOpId(0), swapInOpId(0) {}
std::ostream &operator << (std::ostream &os, const SwapPolicyInfo &obj)
{
os << "SwapPolicyInfo: "
<< "ptr: " << obj.ptr << " "
<< "swapOutOpId: " << obj.swapOutOpId << " "
<< "swapOutStage: " << obj.swapOutStage << " "
<< "swapInOpId: " << obj.swapInOpId << " "
<< "swapInStage.: " << obj.swapInStage << " "
<< "freeStage: " << obj.freeStage << " "
<< "swapInFreeStage.: " << obj.swapInFreeStage << std::endl;
return os;
}
ExecutorTensorInfo::ExecutorTensorInfo()
: opCount(0),
opTag(0),
dtype(at::ScalarType::Byte),
nbytes(0),
shape(0),
opCallsStack(0),
tensorIndexCallsStack(0)
{}
ExecutorTensorInfo::ExecutorTensorInfo(const at::Tensor &tensor, UniqueSwapPtr uniqueSwapPtr)
: ptr(uniqueSwapPtr),
opCount(0),
opTag(0),
dtype(tensor.scalar_type()),
opCallsStack(0),
tensorIndexCallsStack(0)
{
nbytes = tensor.storage().nbytes();
shape = convertShapeToInt64(tensor);
}
ExecutorTensorInfo::ExecutorTensorInfo(const SwapStage &s1, const SwapStage &s2)
: opCount(0),
opTag(0),
dtype(at::ScalarType::Byte),
nbytes(0),
shape(0),
opCallsStack(0),
tensorIndexCallsStack(0),
swapOutStage(s1),
swapInStage(s2)
{}
bool ExecutorTensorInfo::operator == (const ExecutorTensorInfo &other) const
{
return (opCount == other.opCount && opTag == other.opTag && dtype == other.dtype &&
opCallsStack == other.opCallsStack &&
tensorIndexCallsStack == other.tensorIndexCallsStack);
}
std::ostream &operator << (std::ostream &os, const ExecutorTensorInfo &obj)
{
os << "ExecutorTensorInfo: "
<< "ptr: " << obj.ptr << " "
<< "opCount: " << obj.opCount << " "
<< "opTag: " << obj.opTag << " "
<< "nbytes: " << obj.nbytes << " "
<< "shape: " << obj.shape << " "
<< "opCallsStack: " << obj.opCallsStack << " "
<< "tensorIndexCallsStack: " << obj.tensorIndexCallsStack << " "
<< "swapOutStage: " << obj.swapOutStage << " "
<< "swapInStage: " << obj.swapInStage << " "
<< "freeStage: " << obj.freeStage << " "
<< "swapInFreeStage: " << obj.swapInFreeStage << std::endl;
return os;
}
size_t ExecutorTensorInfo::convertShapeToInt64(const at::Tensor &tensor)
{
size_t res = 0;
for (auto s : tensor.sizes()) {
res = (res << 16) + s;
}
return res;
}
size_t ExecutorTensorInfo::convertShapeToInt64(const c10::SmallVector<size_t, N> &sizes)
{
size_t res = 0;
for (auto s : sizes) {
res = (res << 16) + s;
}
return res;
}
void ExecutorTensorInfo::initFromProfilerTensorInfo(const ProfilerTensorInfo &pti)
{
nbytes = pti.nbytes;
shape = convertShapeToInt64(pti.shapeV2);
dtype = pti.dtype;
ptr.ptrBase = pti.ptr.ptrBase;
ptr.index = pti.ptr.index;
}
void ExecutorTensorInfo::updateCallsStack(int opOneHot, int opIndex, int tensorIndex)
{
++opCount;
opTag |= opOneHot;
opCallsStack = (opCallsStack << 8) + opIndex;
tensorIndexCallsStack = (tensorIndexCallsStack << 8) + tensorIndex;
}
SwapExecutor::SwapExecutor() : isInit(false) {}
SwapExecutor::~SwapExecutor()
{
DeInit();
}
int SwapExecutor::Init()
{
if (isInit) {
return 0;
}
this->swapStreams.push_back(getNPUStreamFromPool(c10_npu::current_device()));
isInit = true;
return 0;
}
int SwapExecutor::DeInit()
{
if (!isInit) {
return 0;
}
isInit = false;
return 0;
}
int SwapExecutor::SwapOut(c10::intrusive_ptr<c10::StorageImpl> storageImplPtr, bool isOOM, SwapStage *freeStage)
{
at::DataPtr &dataPtrNpu = storageImplPtr->mutable_data_ptr();
if (dataPtrNpu.device().is_cpu()) {
SWAP_LOG_WARN("SwapOut tensor dataPtr is on cpu, skip.");
return 1;
}
uint64_t uniqueId = static_cast<torch_npu::NPUStorageImpl *>(storageImplPtr.get())->get_unique_id();
auto inEventIter = swapInEventMap.find(uniqueId);
if (inEventIter != swapInEventMap.end()) {
SWAP_LOG_WARN("SwapOut tensor need to process swapin wait task, skip.");
return 1;
}
auto swapOutStorageImpIter = swapOutStorageImplMap.find(uniqueId);
if (swapOutStorageImpIter != swapOutStorageImplMap.end()) {
SWAP_LOG_WARN("Tensor cannot be swapped out twice consecutively, skip.");
return 1;
}
RECORD_FUNCTION("swap_out", std::vector<c10::IValue>({}));
SWAP_LOG_INFO("SwapOut pre, storage uniqueId[%lu], mem ptr on npu[%p][%s]", uniqueId, storageImplPtr->data(),
std::string(NPUSwapManager::GetInstance().getUniqueSwapPtr(storageImplPtr->data())).c_str());
auto allocatorCPU = at_npu::native::getCachingHostAllocator();
size_t size = storageImplPtr->nbytes();
at::DataPtr dataPtrCpu = allocatorCPU->allocate(size);
NPUSwapManager::GetInstance().tensorPtrCountMap[reinterpret_cast<size_t>(dataPtrCpu.get())]++;
if (NPUSwapManager::GetInstance().config.enableProfiler) {
NPUSwapManager::GetInstance().ReportInfoToSwapProfiler(true, dataPtrNpu, dataPtrCpu, size, isOOM);
}
c10_npu::NPUStream &swapStream = this->swapStreams[0];
c10_npu::NPUEvent event;
event.record(c10_npu::getCurrentNPUStream());
event.block(swapStream);
c10_npu::NPUStream currentStream = c10_npu::getCurrentNPUStream();
c10_npu::setCurrentNPUStream(swapStream);
at_npu::native::memory_swap(dataPtrCpu.get(), size, dataPtrNpu.get(), size, 2);
c10_npu::setCurrentNPUStream(currentStream);
if (!NPUSwapManager::GetInstance().config.enableCustomRecordStream) {
c10_npu::NPUCachingAllocator::allocator.load()->recordStream(dataPtrNpu, swapStream);
}
dataPtrCpu.unsafe_set_device(dataPtrNpu.device());
if (NPUSwapManager::GetInstance().config.enableCustomRecordStream) {
NPUSwapManager::GetInstance().RecordStream(dataPtrNpu, swapStream, freeStage);
}
storageImplPtr->set_data_ptr_noswap(std::move(dataPtrCpu));
SWAP_LOG_INFO("SwapOut post, storage uniqueId[%lu], mem ptr on cpu[%p][%s]", uniqueId, storageImplPtr->data(),
std::string(NPUSwapManager::GetInstance().getUniqueSwapPtr(storageImplPtr->data())).c_str());
swapOutStorageImplMap.insert(std::make_pair(uniqueId, c10::weak_intrusive_ptr<c10::StorageImpl>(storageImplPtr)));
return 0;
}
int SwapExecutor::SwapOut(const at::Tensor &tensor, SwapStage *freeStage)
{
c10::intrusive_ptr<c10::StorageImpl> storageImplPtr = tensor.storage().getWeakStorageImpl().lock();
if (!storageImplPtr) {
return 1;
}
return SwapOut(storageImplPtr, false, freeStage);
}
int SwapExecutor::SwapIn(uint64_t uniqueId, bool needWait)
{
auto outTensorIter = swapOutStorageImplMap.find(uniqueId);
if (outTensorIter == swapOutStorageImplMap.end()) {
return 1;
}
c10::intrusive_ptr<c10::StorageImpl> storageImplPtr = outTensorIter->second.lock();
if (!storageImplPtr) {
SWAP_LOG_INFO(
"SwapIn pre: StorageImpl of the tensor for current SwapIn is already destructed since the tensor would \
not be used anymore. swapOutStorageImplMap.find(uniqueId[%lu])->second.weak_count[%zu], use_count[%zu]",
uniqueId, outTensorIter->second.weak_use_count(), outTensorIter->second.use_count());
swapOutStorageImplMap.erase(outTensorIter);
return 1;
}
RECORD_FUNCTION("swap_in", std::vector<c10::IValue>({}));
c10_npu::NPUStream &swapStream = this->swapStreams[0];
c10_npu::NPUEvent beforeSwapInEvent;
beforeSwapInEvent.record(c10_npu::getCurrentNPUStream());
beforeSwapInEvent.block(swapStream);
at::DataPtr &dataPtrCpu = storageImplPtr->mutable_data_ptr();
SWAP_LOG_INFO("SwapIn pre, storage uniqueId[%lu], mem ptr on cpu[%p][%s]", uniqueId, storageImplPtr->data(),
std::string(NPUSwapManager::GetInstance().getUniqueSwapPtr(storageImplPtr->data())).c_str());
auto allocatorNPU = c10_npu::NPUCachingAllocator::allocator.load();
size_t size = storageImplPtr->nbytes();
at::DataPtr dataPtrNpu = allocatorNPU->allocate(size);
if (NPUSwapManager::GetInstance().config.enableProfiler) {
NPUSwapManager::GetInstance().ReportInfoToSwapProfiler(false, dataPtrCpu, dataPtrNpu, size);
}
c10_npu::NPUStream currentStream = c10_npu::getCurrentNPUStream();
c10_npu::setCurrentNPUStream(swapStream);
at_npu::native::memory_swap(dataPtrNpu.get(), size, dataPtrCpu.get(), size, 1);
c10_npu::setCurrentNPUStream(currentStream);
if (NPUSwapManager::GetInstance().config.enableCustomRecordStream) {
auto it = uniqueIdToSwapInFreeStageMap.find(uniqueId);
if (it != uniqueIdToSwapInFreeStageMap.end()) {
NPUSwapManager::GetInstance().RecordStream(storageImplPtr, swapStream, &(it->second));
uniqueIdToSwapInFreeStageMap.erase(it);
} else {
NPUSwapManager::GetInstance().RecordStream(storageImplPtr, swapStream);
}
NPUSwapManager::GetInstance().RecordStream(dataPtrCpu, swapStream);
} else {
c10_npu::NPUCachingAllocator::allocator.load()->recordStream(dataPtrNpu, swapStream);
at_npu::native::CachingHostAllocator_recordEvent(dataPtrCpu.get(), aclrtMemcpyKind::ACL_MEMCPY_HOST_TO_DEVICE, swapStream);
}
storageImplPtr->set_data_ptr_noswap(std::move(dataPtrNpu));
SWAP_LOG_INFO("SwapIn post, storage uniqueId[%lu], mem ptr on npu[%p][%s]", uniqueId, storageImplPtr->data(),
std::string(NPUSwapManager::GetInstance().getUniqueSwapPtr(storageImplPtr->data())).c_str());
c10_npu::NPUEvent afterSwapInEvent;
afterSwapInEvent.record(swapStream);
swapOutStorageImplMap.erase(outTensorIter);
swapInEventMap.insert(std::make_pair(uniqueId, std::move(afterSwapInEvent)));
if (needWait) {
SwapInWait(uniqueId);
}
return 0;
}
int SwapExecutor::SwapInWait(uint64_t uniqueId)
{
auto inEventIter = swapInEventMap.find(uniqueId);
if (inEventIter == swapInEventMap.end()) {
return 1;
}
RECORD_FUNCTION("swap_in_wait", std::vector<c10::IValue>({}));
SWAP_LOG_INFO("SwapIn wait, storage uniqueId[%lu]", uniqueId);
inEventIter->second.block(c10_npu::getCurrentNPUStream());
swapInEventMap.erase(inEventIter);
return 0;
}
void SwapExecutor::CheckAndInsertStorageToMap(const at::Tensor &src, const at::Tensor &dst)
{
uint64_t uniqueIdSrc =
static_cast<torch_npu::NPUStorageImpl *>(src.storage().unsafeGetStorageImpl())->get_unique_id();
auto tensorIter = swapOutStorageImplMap.find(uniqueIdSrc);
if (tensorIter == swapOutStorageImplMap.end()) {
return;
}
c10::intrusive_ptr<c10::StorageImpl> storageImplPtrDst = dst.storage().getWeakStorageImpl().lock();
if (!storageImplPtrDst) {
return;
}
uint64_t uniqueIdDst =
static_cast<torch_npu::NPUStorageImpl *>(dst.storage().unsafeGetStorageImpl())->get_unique_id();
swapOutStorageImplMap.insert(
std::make_pair(uniqueIdDst, c10::weak_intrusive_ptr<c10::StorageImpl>(storageImplPtrDst)));
SWAP_LOG_INFO("Insert storage to SwapOutStorageImplMap, uniqueId[%lu], mem ptr on cpu[%p][%s]", uniqueIdDst,
storageImplPtrDst->data(),
std::string(NPUSwapManager::GetInstance().getUniqueSwapPtr(storageImplPtrDst->data())).c_str());
}
bool SwapExecutor::needGenerateTensorInfo(const at::Tensor &tensor)
{
if (tensor.nbytes() < NPUSwapManager::GetInstance().config.tensorSizeThresh) {
return false;
}
return true;
}
void SwapExecutor::initOpNameToOneHotAndIndexMap(const std::vector<std::string> &opNames)
{
opNameToOneHotAndIndexMap.clear();
size_t oneHot = 1;
size_t opIndex = 1;
for (const auto &opName : opNames) {
opNameToOneHotAndIndexMap[opName] = std::make_pair(oneHot, opIndex);
oneHot = oneHot << 1;
opIndex += 1;
}
}
bool SwapExecutor::checkMatchAndSwapOut(ExecutorTensorInfo &eti, std::vector<ExecutorTensorInfo *> &candidateSwapOutVec)
{
int matchCount = 0;
for (auto it = candidateSwapOutVec.rbegin(); matchCount < candidateSwapOutVec.size() &&
it != candidateSwapOutVec.rend(); ++it) {
++matchCount;
if ((NPUSwapManager::GetInstance().config.stage == (*it)->swapOutStage) && (*(*it)) == eti) {
eti.swapInStage = (*it)->swapInStage;
eti.freeStage = (*it)->freeStage;
eti.swapInFreeStage = (*it)->swapInFreeStage;
candidateSwapOutVec.erase(std::next(it).base());
return true;
}
}
return false;
}
void SwapExecutor::initStanderdSwapOutVec(std::vector<ExecutorTensorInfo *> &standerdSwapOutVec,
const std::vector<ProfilerOpInfo> &opInfosVec, const std::vector<SwapPolicyInfo> &policyInfosVec)
{
for (const auto &policyInfo : policyInfosVec) {
if (!policyInfo.executorNeedMatch) {
continue;
}
ExecutorTensorInfo *eti = new ExecutorTensorInfo(policyInfo.swapOutStage, policyInfo.swapInStage);
for (const auto &opInfoIter : opInfosVec) {
if (opInfoIter.opId > policyInfo.swapOutOpId) {
break;
}
std::pair<size_t, size_t> oneHotAndIndex = GetOpOneHotAndIndex(opInfoIter.opName);
int tensorIndex = 0;
for (const auto &tensorInfoIter : opInfoIter.profilerTensorInfoVec) {
if (tensorInfoIter.ptr == policyInfo.ptr) {
if (eti->opCount == 0) {
eti->initFromProfilerTensorInfo(tensorInfoIter);
eti->swapInStage = policyInfo.swapInStage;
eti->swapOutStage = policyInfo.swapOutStage;
eti->freeStage = policyInfo.freeStage;
eti->swapInFreeStage = policyInfo.swapInFreeStage;
}
eti->updateCallsStack(oneHotAndIndex.first, oneHotAndIndex.second, tensorIndex);
}
tensorIndex++;
}
}
standerdSwapOutVec.push_back(eti);
}
}
void SwapExecutor::initCandidateOptimPolicyVec(const std::vector<SwapPolicyInfo> &policyInfosVec)
{
for (const auto &policyInfo : policyInfosVec) {
if (policyInfo.executorNeedMatch) {
continue;
}
candidateOptimPolicyVec.emplace_back(policyInfo);
}
}
void SwapExecutor::processOptimTask(std::unordered_map<UniqueSwapPtr, c10::weak_intrusive_ptr<c10::StorageImpl>,
HashUniqueSwapPtr> &tensorPtrWeakPtrMap)
{
for (const auto &policyInfo : candidateOptimPolicyVec) {
auto weakPtr = tensorPtrWeakPtrMap.find(policyInfo.ptr);
if (weakPtr != tensorPtrWeakPtrMap.end()) {
auto storageImplPtr = weakPtr->second.lock();
if (!storageImplPtr) {
continue;
}
auto tensorToSwapOutVecIter = stageToSwapOutMap
.try_emplace(policyInfo.swapOutStage,
c10::SmallVector<c10::weak_intrusive_ptr<c10::StorageImpl>, N>())
.first;
tensorToSwapOutVecIter->second.push_back(weakPtr->second);
auto tensorToSwapInVecIter = stageToSwapInMap
.try_emplace(policyInfo.swapInStage,
c10::SmallVector<c10::weak_intrusive_ptr<c10::StorageImpl>, N>())
.first;
tensorToSwapInVecIter->second.push_back(weakPtr->second);
auto iter = stageToOptimFreeStageMap.try_emplace(policyInfo.swapOutStage, std::vector<SwapStage>()).first;
iter->second.push_back(policyInfo.freeStage);
uint64_t uniqueId =
static_cast<torch_npu::NPUStorageImpl *>(storageImplPtr.get())->get_unique_id();
uniqueIdToSwapInFreeStageMap[uniqueId] = policyInfo.swapInFreeStage;
}
}
}
std::pair<size_t, size_t> SwapExecutor::GetOpOneHotAndIndex(const std::string &opName)
{
auto it = opNameToOneHotAndIndexMap.find(opName);
if (it != opNameToOneHotAndIndexMap.end()) {
return it->second;
}
return std::pair<size_t, size_t>(0, 0);
}
void SwapExecutor::ProcessTensorMatchTask(const std::string &opName, const c10::SmallVector<at::Tensor, N> &curTensors)
{
if (candidateSwapOutVec.empty()) {
return;
}
std::pair<size_t, size_t> oneHotAndIndex = GetOpOneHotAndIndex(opName);
int tensorIndex = 0;
for (const auto &tensor : curTensors) {
if (needGenerateTensorInfo(tensor)) {
UniqueSwapPtr uniqueSwapPtr = NPUSwapManager::GetInstance().getUniqueSwapPtr(tensor);
auto executorTensorInfoIter = ptrToTensorInfoMap.find(uniqueSwapPtr);
if (executorTensorInfoIter == ptrToTensorInfoMap.end()) {
executorTensorInfoIter =
ptrToTensorInfoMap.try_emplace(uniqueSwapPtr, ExecutorTensorInfo(tensor, uniqueSwapPtr)).first;
}
(executorTensorInfoIter->second).updateCallsStack(oneHotAndIndex.first, oneHotAndIndex.second, tensorIndex);
if (checkMatchAndSwapOut(executorTensorInfoIter->second, candidateSwapOutVec)) {
SwapOut(tensor, &(executorTensorInfoIter->second.freeStage));
auto tensorToSwapInVecIter = stageToSwapInMap
.try_emplace((executorTensorInfoIter->second).swapInStage,
c10::SmallVector<c10::weak_intrusive_ptr<c10::StorageImpl>, N>())
.first;
tensorToSwapInVecIter->second.push_back(tensor.storage().getWeakStorageImpl());
uint64_t uniqueId = static_cast<torch_npu::NPUStorageImpl *>(tensor.storage().unsafeGetStorageImpl())
->get_unique_id();
uniqueIdToSwapInFreeStageMap[uniqueId] = executorTensorInfoIter->second.swapInFreeStage;
}
}
tensorIndex++;
}
}
void SwapExecutor::ProcessStageMatchTask(const SwapStage ¤tStage)
{
auto itOut = stageToSwapOutMap.find(currentStage);
if (itOut != stageToSwapOutMap.end()) {
auto tempIter = stageToOptimFreeStageMap.find(currentStage);
int count = 0;
for (auto &storageImpl : itOut->second) {
auto storageImplPtr = storageImpl.lock();
if (!storageImplPtr) {
count++;
continue;
}
SwapOut(storageImplPtr, false, &(tempIter->second[count++]));
}
stageToSwapOutMap.erase(itOut);
stageToOptimFreeStageMap.erase(tempIter);
}
auto itIn = stageToSwapInMap.find(currentStage);
if (itIn != stageToSwapInMap.end()) {
for (auto storageImpl = itIn->second.rbegin(); storageImpl != itIn->second.rend(); ++storageImpl) {
SwapIn(*storageImpl);
}
stageToSwapInMap.erase(itIn);
}
}
void SwapExecutor::clearStanderdSwapOutVec()
{
for (auto it = standerdSwapOutVec.begin(); it != standerdSwapOutVec.end(); ++it) {
delete *it;
}
standerdSwapOutVec.clear();
}
void SwapExecutor::clearCandidateOptimPolicyVec()
{
candidateOptimPolicyVec.clear();
}
void SwapExecutor::SwapIn(c10::weak_intrusive_ptr<c10::StorageImpl> &storageImplWeakPtr)
{
auto storageImplPtr = storageImplWeakPtr.lock();
if (!storageImplPtr) {
return;
}
uint64_t uniqueId = static_cast<torch_npu::NPUStorageImpl *>(storageImplPtr.get())->get_unique_id();
SwapIn(uniqueId, false);
}
void SwapExecutor::SwapOut(c10::weak_intrusive_ptr<c10::StorageImpl> &storageImplWeakPtr)
{
auto storageImplPtr = storageImplWeakPtr.lock();
if (!storageImplPtr) {
return;
}
SwapOut(storageImplPtr);
}
void SwapExecutor::updateStep(std::unordered_map<UniqueSwapPtr, c10::weak_intrusive_ptr<c10::StorageImpl>,
HashUniqueSwapPtr> &tensorPtrWeakPtrMap)
{
ptrToTensorInfoMap.clear();
candidateSwapOutVec.clear();
candidateSwapOutVec.resize(standerdSwapOutVec.size());
std::reverse_copy(standerdSwapOutVec.begin(), standerdSwapOutVec.end(), candidateSwapOutVec.begin());
processOptimTask(tensorPtrWeakPtrMap);
}
template <class T> RecordStreamManager<T>::RecordStreamManager() : isInit(false) {}
template <class T> RecordStreamManager<T>::~RecordStreamManager()
{
isInit = false;
}
template <class T> int RecordStreamManager<T>::Init()
{
if (isInit) {
return 0;
}
isInit = true;
return 0;
}
template <class T> int RecordStreamManager<T>::DeInit()
{
isInit = false;
return 0;
}
template <class T> void RecordStreamManager<T>::RecordStream(T &ptr, c10_npu::NPUStream stream)
{
if (!isInit) {
return;
}
c10_npu::NPUEvent recordStreamEvent;
recordStreamEvent.record(stream);
recordedQueue.push_back(std::make_pair(std::move(ptr), std::move(recordStreamEvent)));
}
template <class T> void RecordStreamManager<T>::ProcessEvent()
{
if (!isInit) {
return;
}
while (!recordedQueue.empty()) {
auto &recordStreamEvent = recordedQueue.front().second;
if (recordStreamEvent.query()) {
recordedQueue.pop_front();
} else {
break;
}
}
}
template <class T> bool RecordStreamManager<T>::ProcessMallocEvent()
{
if (!isInit) {
return false;
}
bool res = false;
while (!recordedQueue.empty()) {
auto &recordStreamEvent = recordedQueue.front().second;
recordStreamEvent.block(c10_npu::getCurrentNPUStream());
recordedQueue.pop_front();
res = true;
}
return res;
}
template <class T> RecordStreamWithFreeStageManager<T>::RecordStreamWithFreeStageManager() : isInit(false) {}
template <class T> RecordStreamWithFreeStageManager<T>::~RecordStreamWithFreeStageManager()
{
isInit = false;
}
template <class T> int RecordStreamWithFreeStageManager<T>::Init()
{
if (isInit) {
return 0;
}
isInit = true;
return 0;
}
template <class T> int RecordStreamWithFreeStageManager<T>::DeInit()
{
isInit = false;
return 0;
}
template <class T>
void RecordStreamWithFreeStageManager<T>::RecordStream(T &ptr, c10_npu::NPUStream stream, SwapStage &freeStage)
{
if (!isInit) {
return;
}
c10_npu::NPUEvent recordStreamEvent;
recordStreamEvent.record(stream);
auto stageToFreeIter =
StageToFreeEventMap.try_emplace(freeStage, std::deque<std::pair<T, c10_npu::NPUEvent>>()).first;
stageToFreeIter->second.push_back(std::make_pair(std::move(ptr), std::move(recordStreamEvent)));
}
template <class T> void RecordStreamWithFreeStageManager<T>::ProcessEvent()
{
if (!isInit) {
return;
}
for (const auto &pair : StageToFreeEventMap) {
const SwapStage &stage = pair.first;
const std::pair<T, c10_npu::NPUEvent> &recordedQueue = pair.second;
while (!recordedQueue.empty()) {
auto &recordStreamEvent = recordedQueue.front().second;
if (recordStreamEvent.query()) {
recordedQueue.pop_front();
} else {
break;
}
}
}
}
template <class T> bool RecordStreamWithFreeStageManager<T>::FreeEventWithStage(SwapStage &freeStage)
{
if (!isInit) {
return false;
}
bool res = false;
auto stageToFreeIter = StageToFreeEventMap.find(freeStage);
if (stageToFreeIter == StageToFreeEventMap.end()) {
return false;
}
auto &recordedQueue = stageToFreeIter->second;
while (!recordedQueue.empty()) {
auto &recordStreamEvent = recordedQueue.front().second;
recordStreamEvent.block(c10_npu::getCurrentNPUStream());
recordedQueue.pop_front();
res = true;
}
return res;
}
template <class T> bool RecordStreamWithFreeStageManager<T>::ProcessMallocEvent()
{
if (!isInit) {
return false;
}
bool res = false;
for (auto &pair : StageToFreeEventMap) {
const SwapStage &stage = pair.first;
auto &recordedQueue = pair.second;
while (!recordedQueue.empty()) {
auto &recordStreamEvent = recordedQueue.front().second;
recordStreamEvent.block(c10_npu::getCurrentNPUStream());
recordedQueue.pop_front();
res = true;
}
}
return res;
}
NPUSwapManager::NPUSwapManager()
: swap_enable(false),
swap_oom_enable(false),
isInit(false),
executor(nullptr),
profiler(nullptr),
opId(0),
recordedDataPtrManager(nullptr),
recordedStorageImplManager(nullptr),
recordedDataPtrWithFreeStageManager(nullptr),
recordedStorageImplWithFreeStageManager(nullptr)
{}
NPUSwapManager::~NPUSwapManager()
{
DeInit();
}
NPUSwapManager &NPUSwapManager::GetInstance()
{
static NPUSwapManager instance;
return instance;
}
int NPUSwapManager::Init()
{
if (isInit) {
return 0;
}
if (executor == nullptr) {
executor = new SwapExecutor();
if (executor != nullptr) {
executor->Init();
}
}
if (profiler == nullptr) {
profiler = new SwapProfiler();
if (profiler != nullptr) {
profiler->Init();
}
}
if (recordedDataPtrManager == nullptr) {
recordedDataPtrManager = new RecordStreamManager<at::DataPtr>();
if (recordedDataPtrManager != nullptr) {
recordedDataPtrManager->Init();
}
}
if (recordedStorageImplManager == nullptr) {
recordedStorageImplManager = new RecordStreamManager<c10::intrusive_ptr<c10::StorageImpl>>();
if (recordedStorageImplManager != nullptr) {
recordedStorageImplManager->Init();
}
}
if (recordedDataPtrWithFreeStageManager == nullptr) {
recordedDataPtrWithFreeStageManager = new RecordStreamWithFreeStageManager<at::DataPtr>();
if (recordedDataPtrWithFreeStageManager != nullptr) {
recordedDataPtrWithFreeStageManager->Init();
}
}
if (recordedStorageImplWithFreeStageManager == nullptr) {
recordedStorageImplWithFreeStageManager =
new RecordStreamWithFreeStageManager<c10::intrusive_ptr<c10::StorageImpl>>();
if (recordedStorageImplWithFreeStageManager != nullptr) {
recordedStorageImplWithFreeStageManager->Init();
}
}
at_npu::native::RegisterOpHookBeginFn(
[](const std::string &op_name) -> void { c10_npu::swap::NPUSwapManager::GetInstance().BeginHook(op_name); });
at_npu::native::RegisterOpHookEndFn([]() -> void {
c10_npu::swap::NPUSwapManager::GetInstance().PostHook();
c10_npu::swap::NPUSwapManager::GetInstance().EndHook();
});
at_npu::native::RegisterOpHookPreFn([](const at::Tensor &at_tensor) -> void {
if (!at_tensor.defined()) {
return;
}
c10_npu::swap::NPUSwapManager::GetInstance().TensorHook(at_tensor);
});
at_npu::native::RegisterOpHookPostFn([](const at::Tensor &at_tensor) -> void {
if (!at_tensor.defined()) {
return;
}
c10_npu::swap::NPUSwapManager::GetInstance().TensorHook(at_tensor);
});
isInit = true;
return 0;
}
int NPUSwapManager::DeInit()
{
if (!isInit) {
return 0;
}
if (executor != nullptr) {
delete executor;
executor = nullptr;
}
if (profiler != nullptr) {
delete profiler;
profiler = nullptr;
}
if (recordedDataPtrManager != nullptr) {
delete recordedDataPtrManager;
recordedDataPtrManager = nullptr;
}
if (recordedStorageImplManager != nullptr) {
delete recordedStorageImplManager;
recordedStorageImplManager = nullptr;
}
if (recordedDataPtrWithFreeStageManager != nullptr) {
delete recordedDataPtrWithFreeStageManager;
recordedDataPtrWithFreeStageManager = nullptr;
}
if (recordedStorageImplWithFreeStageManager != nullptr) {
delete recordedStorageImplWithFreeStageManager;
recordedStorageImplWithFreeStageManager = nullptr;
}
isInit = false;
return 0;
}
void NPUSwapManager::RecordStream(at::DataPtr &dataPtr, c10_npu::NPUStream stream, SwapStage *freeStage)
{
if (!isInit) {
return;
}
if (freeStage == nullptr) {
recordedDataPtrManager->RecordStream(dataPtr, stream);
} else {
recordedDataPtrWithFreeStageManager->RecordStream(dataPtr, stream, *freeStage);
}
}
void NPUSwapManager::RecordStream(c10::intrusive_ptr<c10::StorageImpl> storageImpl, c10_npu::NPUStream stream,
SwapStage *freeStage)
{
if (!isInit) {
return;
}
if (freeStage == nullptr) {
recordedStorageImplManager->RecordStream(storageImpl, stream);
} else {
recordedStorageImplWithFreeStageManager->RecordStream(storageImpl, stream, *freeStage);
}
}
void NPUSwapManager::ProcessEvent()
{
if (!isInit) {
return;
}
recordedDataPtrManager->ProcessEvent();
recordedStorageImplManager->ProcessEvent();
}
bool NPUSwapManager::ProcessMallocEvent()
{
if (!isInit) {
return false;
}
if (!config.enableCustomRecordStream) {
return false;
}
bool res = recordedDataPtrManager->ProcessMallocEvent();
res = res || recordedDataPtrWithFreeStageManager->ProcessMallocEvent();
res = res || recordedStorageImplWithFreeStageManager->ProcessMallocEvent();
return res;
}
int NPUSwapManager::BeginHook(const std::string &opName)
{
if (!isInit) {
return 0;
}
SWAP_LOG_INFO("BeginHook in, opIdStk.size[%zu], opNameStk.size[%zu], curTensorsStk.size[%zu]", opIdStk.size(),
curOpNameStk.size(), curTensorsStk.size());
opIdStk.push_back(opId);
opId++;
curOpNameStk.push_back(opName);
c10::SmallVector<at::Tensor, N> curTensors;
curTensorsStk.push_back(curTensors);
ProcessEvent();
SWAP_LOG_INFO("BeginHook out, opId[%d], opName[%s], curTensors num[%zu]", opIdStk.back(),
curOpNameStk.back().c_str(), curTensorsStk.back().size());
return 0;
}
int NPUSwapManager::EndHook()
{
if (!isInit) {
return 0;
}
SWAP_LOG_INFO("EndHook in, opId[%d], opName[%s], curTensors num[%zu]", opIdStk.back(), curOpNameStk.back().c_str(),
curTensorsStk.back().size());
for (auto &tensor : curTensorsStk.back()) {
SaveTensor(tensor);
}
tensorValidMap.clear();
for (size_t i = 0; i < curTensorsStk.back().size(); ++i) {
SWAP_LOG_DEBUG(
"EndHook post, opId[%d], opName[%s], curTensors num[%zu], idx[%zu], storage uniqueId[%lu], mem ptr[%p][%s]",
opIdStk.back(), curOpNameStk.back().c_str(), curTensorsStk.back().size(), i,
static_cast<torch_npu::NPUStorageImpl *>(curTensorsStk.back()[i].storage().unsafeGetStorageImpl())
->get_unique_id(),
curTensorsStk.back()[i].storage().data(),
std::string(getUniqueSwapPtr(curTensorsStk.back()[i])).c_str());
}
opIdStk.pop_front();
curOpNameStk.pop_back();
curTensorsStk.pop_back();
SWAP_LOG_INFO("EndHook out, opIdStk.size[%zu], opNameStk.size[%zu], curTensorsStk.size[%zu]", opIdStk.size(),
curOpNameStk.size(), curTensorsStk.size());
return 0;
}
int NPUSwapManager::TensorHook(const at::Tensor &tensor)
{
if (!isInit) {
return 0;
}
if (tensor.numel() == 0) {
return 1;
}
if (!tensor.device().is_privateuseone()) {
return 1;
}
uint64_t uniqueId = static_cast<torch_npu::NPUStorageImpl *>(tensor.storage().unsafeGetStorageImpl())
->get_unique_id();
SWAP_LOG_INFO("TensorHook in, before process, opId[%d], opName[%s], curTensors num[%zu], storage uniqueId[%lu], "
"mem ptr[%p][%s]",
opIdStk.back(), curOpNameStk.back().c_str(), curTensorsStk.back().size(), uniqueId, tensor.storage().data(),
std::string(getUniqueSwapPtr(tensor)).c_str());
curTensorsStk.back().emplace_back(tensor);
tensorValidMap[tensor.storage().mutable_data()] = true;
executor->SwapInWait(uniqueId);
executor->SwapIn(uniqueId, true);
SWAP_LOG_INFO("TensorHook out, after process, opId[%d], opName[%s], curTensors num[%zu], storage uniqueId[%lu], "
"mem ptr[%p][%s]",
opIdStk.back(), curOpNameStk.back().c_str(), curTensorsStk.back().size(), uniqueId, tensor.storage().data(),
std::string(getUniqueSwapPtr(tensor)).c_str());
return 0;
}
int NPUSwapManager::TensorHook(const at::TensorList &tensor)
{
int status = 0;
for (int i = 0; i < tensor.size(); ++i) {
status = TensorHook(tensor[i]);
}
return status;
}
int NPUSwapManager::TensorHook(const std::vector<at::Tensor> &tensor)
{
int status = 0;
for (int i = 0; i < tensor.size(); ++i) {
status = TensorHook(tensor[i]);
}
return status;
}
int NPUSwapManager::PostHook()
{
if (!isInit) {
return 0;
}
SWAP_LOG_INFO("PostHook in, opId[%d], opName[%s], curTensors num[%zu]", opIdStk.back(), curOpNameStk.back().c_str(),
curTensorsStk.back().size());
for (size_t i = 0; i < curTensorsStk.back().size(); ++i) {
SWAP_LOG_DEBUG("PostHook before process, opId[%d], opName[%s], curTensors num[%zu], idx[%zu], storage \
uniqueId[%lu], mem ptr[%p][%s]",
opIdStk.back(), curOpNameStk.back().c_str(), curTensorsStk.back().size(), i,
static_cast<torch_npu::NPUStorageImpl *>(curTensorsStk.back()[i].storage().unsafeGetStorageImpl())
->get_unique_id(),
curTensorsStk.back()[i].storage().data(),
std::string(getUniqueSwapPtr(curTensorsStk.back()[i])).c_str());
}
if (config.enableProfiler) {
profiler->ReportInfoV2(curOpNameStk.back(), opIdStk.front(), curTensorsStk.back());
}
if (config.enableExecutor) {
executor->ProcessTensorMatchTask(curOpNameStk.back(), curTensorsStk.back());
executor->ProcessStageMatchTask(config.stage);
recordedDataPtrWithFreeStageManager->FreeEventWithStage(config.stage);
recordedStorageImplWithFreeStageManager->FreeEventWithStage(config.stage);
UpdateCurrentStagePerOp();
}
for (size_t i = 0; i < curTensorsStk.back().size(); ++i) {
SWAP_LOG_DEBUG("PostHook after process, opId[%d], opName[%s], curTensors num[%zu], idx[%zu], storage \
uniqueId[%lu], mem ptr[%p][%s]",
opIdStk.back(), curOpNameStk.back().c_str(), curTensorsStk.back().size(), i,
static_cast<torch_npu::NPUStorageImpl *>(curTensorsStk.back()[i].storage().unsafeGetStorageImpl())
->get_unique_id(),
curTensorsStk.back()[i].storage().data(),
std::string(getUniqueSwapPtr(curTensorsStk.back()[i])).c_str());
}
SWAP_LOG_INFO("PostHook out, opId[%d], opName[%s], curTensors num[%zu]", opIdStk.back(),
curOpNameStk.back().c_str(), curTensorsStk.back().size());
return 0;
}
void NPUSwapManager::SaveTensor(const at::Tensor &tensor)
{
if (!swap_oom_enable) {
return;
}
void *dataPtr = tensor.storage().mutable_data();
auto storageImplIter = storageImplMap.find(dataPtr);
if (storageImplIter == storageImplMap.end()) {
storageImplMap.emplace(dataPtr, tensor.storage().getWeakStorageImpl());
} else {
storageImplMap.erase(storageImplIter);
storageImplMap.emplace(dataPtr, tensor.storage().getWeakStorageImpl());
}
auto it =
std::find_if(tensorQueue.begin(), tensorQueue.end(), [&dataPtr](const void *ptr) { return ptr == dataPtr; });
if (it != tensorQueue.end()) {
tensorQueue.erase(it);
}
tensorQueue.push_back(dataPtr);
}
void NPUSwapManager::CheckAndSwapOutForOOM(void *ptrInBlock)
{
if (!swap_oom_enable) {
return;
}
auto storageImplIter = storageImplMap.find(ptrInBlock);
if (storageImplIter == storageImplMap.end()) {
return;
}
c10::intrusive_ptr<c10::StorageImpl> storageImplPtr = storageImplIter->second.lock();
if (storageImplPtr) {
auto validIter = tensorValidMap.find(ptrInBlock);
if (validIter == tensorValidMap.end()) {
auto blacklistIter = ptrBlacklist.find(getUniqueSwapPtr(storageImplPtr->mutable_data()));
if (blacklistIter == ptrBlacklist.end()) {
executor->SwapOut(storageImplPtr, true);
c10_npu::NPUStream &swapStream = executor->swapStreams[0];
swapStream.synchronize();
}
}
}
storageImplMap.erase(storageImplIter);
auto it = std::find_if(tensorQueue.begin(), tensorQueue.end(),
[&ptrInBlock](const void *ptr) { return ptr == ptrInBlock; });
if (it != tensorQueue.end()) {
tensorQueue.erase(it);
}
}
std::map<void *, c10::weak_intrusive_ptr<c10::StorageImpl>> &NPUSwapManager::GetStorageImplMap()
{
return storageImplMap;
}
std::deque<void *> &NPUSwapManager::GetTensorQueue()
{
return tensorQueue;
}
void NPUSwapManager::ReportInfoToSwapProfiler(bool isSwapOut, at::DataPtr &srcDataPtr, at::DataPtr &dstDataPtr,
size_t size, bool isOOM)
{
if (!isInit) {
return;
}
profiler->ReportInfoV2(isSwapOut, srcDataPtr, dstDataPtr, size, isOOM);
}
void NPUSwapManager::CheckAndInsertStorageToMap(const at::Tensor &src, const at::Tensor &dst)
{
if (!isInit) {
return;
}
executor->CheckAndInsertStorageToMap(src, dst);
}
UniqueSwapPtr NPUSwapManager::getUniqueSwapPtr(const at::Tensor &tensor)
{
size_t ptrBase = reinterpret_cast<size_t>(tensor.storage().data());
UniqueSwapPtr uniqueSwapPtr;
uniqueSwapPtr.ptrBase = ptrBase;
auto it = tensorPtrCountMap.find(ptrBase);
if (it == tensorPtrCountMap.end()) {
uniqueSwapPtr.index = 0;
} else {
uniqueSwapPtr.index = tensorPtrCountMap[ptrBase];
}
return uniqueSwapPtr;
}
UniqueSwapPtr NPUSwapManager::getUniqueSwapPtr(const void *storagePtr)
{
size_t ptrBase = reinterpret_cast<size_t>(storagePtr);
UniqueSwapPtr uniqueSwapPtr;
uniqueSwapPtr.ptrBase = ptrBase;
auto it = tensorPtrCountMap.find(ptrBase);
if (it == tensorPtrCountMap.end()) {
uniqueSwapPtr.index = 0;
} else {
uniqueSwapPtr.index = tensorPtrCountMap[ptrBase];
}
return uniqueSwapPtr;
}
UniqueSwapPtr NPUSwapManager::getUniqueSwapPtr(size_t p)
{
size_t ptrBase = p;
UniqueSwapPtr uniqueSwapPtr;
uniqueSwapPtr.ptrBase = ptrBase;
auto it = tensorPtrCountMap.find(ptrBase);
if (it == tensorPtrCountMap.end()) {
uniqueSwapPtr.index = 0;
} else {
uniqueSwapPtr.index = tensorPtrCountMap[ptrBase];
}
return uniqueSwapPtr;
}
std::vector<UniqueSwapPtr> NPUSwapManager::recordTensorPtrWithTypes(const std::vector<at::Tensor> &tensors,
SwapTensorType type, int updateWeakPtrMap, bool isUpdateBlacklist)
{
if (updateWeakPtrMap == 1) {
tensorPtrWeakPtrMap.clear();
}
std::vector<UniqueSwapPtr> results;
results.reserve(tensors.size());
for (const auto &tensor : tensors) {
auto uniquePtr = getUniqueSwapPtr(tensor);
tensorPtrTypeMap.try_emplace(uniquePtr, type);
if (updateWeakPtrMap > 0) {
tensorPtrWeakPtrMap.try_emplace(uniquePtr, tensor.storage().getWeakStorageImpl());
}
if (isUpdateBlacklist) {
ptrBlacklist.insert(uniquePtr);
}
results.emplace_back(uniquePtr);
}
return results;
}
void NPUSwapManager::initOpNameToOneHotAndIndexMap(std::vector<std::string> &frequentOpNames)
{
executor->initOpNameToOneHotAndIndexMap(frequentOpNames);
}
void NPUSwapManager::FunAfterProfiler(std::vector<SwapPolicyInfo> &policyInfoVec)
{
if (!isInit) {
return;
}
if (config.enableExecutor) {
executor->clearStanderdSwapOutVec();
executor->initStanderdSwapOutVec(executor->standerdSwapOutVec, profiler->getPolicyStepOpVec(), policyInfoVec);
executor->clearCandidateOptimPolicyVec();
executor->initCandidateOptimPolicyVec(policyInfoVec);
}
}
void NPUSwapManager::UpdateCurrentStagePerOp()
{
if (config.fwdOpLayerInfo.empty() || config.bwdOpLayerInfo.empty()) {
return;
}
config.currentStageOpId++;
auto mbId = config.stage.microBatchIndex;
auto modelId = config.stage.modelIndex;
auto numMicroBatches = config.fwdOpLayerInfo[modelId].size();
if (numMicroBatches == 1) {
mbId = 1;
modelId = 0;
}
auto opLayerInfo = config.fwdOpLayerInfo;
if (config.stage.stageType == SwapStageType::FWD) {
opLayerInfo = config.fwdOpLayerInfo;
} else if (config.stage.stageType == SwapStageType::BWD) {
opLayerInfo = config.bwdOpLayerInfo;
} else {
return;
}
if (config.currentStageOpId > opLayerInfo[modelId][mbId].back()) {
config.stage.layerIndex = opLayerInfo[modelId][mbId].size() + 1;
} else {
for (int i = 0; i < opLayerInfo[modelId][mbId].size(); i++) {
if (config.currentStageOpId <= opLayerInfo[modelId][mbId][i]) {
config.stage.layerIndex = i + 1;
break;
}
}
}
}
void NPUSwapManager::updateStep()
{
if (!isInit) {
return;
}
config.currentStageOpId = 0;
executor->updateStep(tensorPtrWeakPtrMap);
tensorQueue.clear();
config.isOOM = false;
}
c10_npu::NPUStream &NPUSwapManager::GetSwapStream()
{
return this->executor->swapStreams[0];
}
}
}