/**
* Copyright (c) 2025-2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef PASS_OSP_SARKAR_TPP
#define PASS_OSP_SARKAR_TPP
namespace npu::tile_fwk {
namespace osp {
template <typename IntegralType>
IntegralType IntSqrtFloor(IntegralType num)
{
static_assert(std::is_integral_v<IntegralType>);
if (num <= 0) {
return 0;
}
constexpr IntegralType numberTwo = 2;
constexpr IntegralType numberFour = numberTwo * numberTwo;
IntegralType sqrt = 1;
IntegralType numCopy = num;
while (numCopy >= numberFour) {
sqrt *= numberTwo;
numCopy /= numberFour;
}
IntegralType power2 = sqrt / numberTwo;
while (power2 > 0) {
IntegralType sum = sqrt + power2;
if (sum * sum <= num) {
sqrt = sum;
}
power2 /= numberTwo;
}
return sqrt;
}
template <typename IntegralType>
std::vector<IntegralType> DivisorsList(IntegralType num)
{
static_assert(std::is_integral_v<IntegralType>);
if (num == 0) {
return std::vector<IntegralType>({0});
} else if (num < 0) {
return std::vector<IntegralType>();
}
std::vector<IntegralType> divs;
const IntegralType ub = IntSqrtFloor<IntegralType>(num);
for (IntegralType div = 1; div <= ub; ++div) {
if (num % div == 0) {
divs.emplace_back(div);
}
}
constexpr std::size_t numberTwo = 2U;
const std::size_t beginIndx = divs.back() * divs.back() == num ? divs.size() - numberTwo : divs.size() - 1U;
for (std::size_t indx = beginIndx; indx != std::numeric_limits<std::size_t>::max(); --indx) {
divs.emplace_back(num / divs[indx]);
}
return divs;
}
template <typename GraphTIn, typename GraphTOut>
VertexIdxT<GraphTIn> Sarkar<GraphTIn, GraphTOut>::AllParentsContraction(
VWorkwT<GraphTIn> commCost, const GraphTIn &graph,
std::vector<std::vector<VertexIdxT<GraphTIn>>> &expansionMapOutput) const
{
using VertexType = VertexIdxT<GraphTIn>;
expansionMapOutput.clear();
const std::vector<VertexIdxT<GraphTIn>> vertexPoset = GetBotPosetMap(graph);
const std::vector<VWorkwT<GraphTIn>> topDist = GetTopDistance(commCost, graph);
const std::vector<VWorkwT<GraphTIn>> botDist = GetBotDistance(commCost, graph);
auto cmp = [](const std::pair<long, VertexType> &lhs, const std::pair<long, VertexType> &rhs) {
return (lhs.first > rhs.first) || ((lhs.first == rhs.first) && (lhs.second < rhs.second));
};
std::set<std::pair<long, VertexType>, decltype(cmp)> vertPriority(cmp);
for (const VertexType &groupFoot : graph.Vertices()) {
if (graph.InDegree(groupFoot) < 2) {
continue;
}
bool shouldSkip = false;
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
if (graph.VertexType(groupHead) != graph.VertexType(groupFoot)) {
shouldSkip = true;
break;
}
}
if (shouldSkip) {
continue;
}
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
if (vertexPoset[groupFoot] != vertexPoset[groupHead] + 1) {
shouldSkip = true;
break;
}
}
if (shouldSkip) {
continue;
}
VWorkwT<GraphTIn> combinedWeight = graph.VertexWorkWeight(groupFoot);
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
combinedWeight += graph.VertexWorkWeight(groupHead);
}
if (combinedWeight > params_.maxWeight_) {
continue;
}
VWorkwT<GraphTIn> maxPath = topDist[groupFoot] + botDist[groupFoot] - graph.VertexWorkWeight(groupFoot);
for (const VertexType &par : graph.Parents(groupFoot)) {
maxPath = std::max(maxPath, topDist[par] + botDist[par] - graph.VertexWorkWeight(par));
}
VWorkwT<GraphTIn> maxParentDist = 0;
VWorkwT<GraphTIn> maxChildDist = 0;
for (const VertexType &child : graph.Children(groupFoot)) {
maxChildDist = std::max(maxChildDist, botDist[child] + commCost);
}
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
for (const VertexType &chld : graph.Children(groupHead)) {
if (chld == groupFoot) {
continue;
}
maxChildDist = std::max(maxChildDist, botDist[chld] + commCost);
}
}
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
for (const VertexType &par : graph.Parents(groupHead)) {
maxParentDist = std::max(maxParentDist, topDist[par] + commCost);
}
}
VWorkwT<GraphTIn> newMaxPath = maxParentDist + maxChildDist + graph.VertexWorkWeight(groupFoot);
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
newMaxPath += graph.VertexWorkWeight(groupHead);
}
long savings = maxPath - newMaxPath;
if (savings + static_cast<long>(params_.leniency_ * static_cast<double>(maxPath)) >= 0) {
vertPriority.emplace(savings, groupFoot);
}
}
std::vector<bool> partitionedFlag(graph.NumVertices(), false);
VertexIdxT<GraphTIn> maxCorseningNum
= graph.NumVertices()
- static_cast<VertexIdxT<GraphTIn>>(
static_cast<double>(graph.NumVertices()) * params_.geomDecay_);
VertexIdxT<GraphTIn> counter = 0;
long minSave = std::numeric_limits<long>::lowest();
for (auto prioIter = vertPriority.begin(); prioIter != vertPriority.end(); prioIter++) {
const long &vertSave = prioIter->first;
const VertexType &groupFoot = prioIter->second;
// Iterations halt
if (vertSave < minSave) {
break;
}
// Check whether we can glue
if (partitionedFlag[groupFoot]) {
continue;
}
bool shouldSkip = false;
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
if (partitionedFlag[groupHead]) {
shouldSkip = true;
break;
}
}
if (shouldSkip) {
continue;
}
// Adding to partition
std::vector<VertexType> part;
part.reserve(1 + graph.InDegree(groupFoot));
part.emplace_back(groupFoot);
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
part.emplace_back(groupHead);
}
expansionMapOutput.emplace_back(std::move(part));
counter += static_cast<VertexIdxT<GraphTIn>>(graph.InDegree(groupFoot));
if (counter > maxCorseningNum) {
minSave = vertSave;
}
partitionedFlag[groupFoot] = true;
for (const VertexType &groupHead : graph.Parents(groupFoot)) {
partitionedFlag[groupHead] = true;
}
}
for (const VertexType &vert : graph.Vertices()) {
if (partitionedFlag[vert]) continue;
expansionMapOutput.emplace_back(std::initializer_list<VertexType>{vert});
}
return counter;
}
template <typename GraphTIn, typename GraphTOut>
std::vector<std::vector<VertexIdxT<GraphTIn>>>
Sarkar<GraphTIn, GraphTOut>::GenerateVertexExpansionMap(
const GraphTIn &dagIn, VertexIdxT<GraphTIn> &diff)
{
std::vector<std::vector<VertexIdxT<GraphTIn>>> expansionMap;
switch (params_.mode_) {
case sarkar_params::Mode::LINES:
{
diff = SingleContraction(params_.commCost_, dagIn, expansionMap);
} break;
case sarkar_params::Mode::FAN_IN_FULL:
{
diff = AllParentsContraction(params_.commCost_, dagIn, expansionMap);
} break;
case sarkar_params::Mode::FAN_IN_PARTIAL:
{
diff = SomeParentsContraction(params_.commCost_, dagIn, expansionMap);
} break;
case sarkar_params::Mode::FAN_OUT_FULL:
{
diff = AllChildrenContraction(params_.commCost_, dagIn, expansionMap);
} break;
case sarkar_params::Mode::FAN_OUT_PARTIAL:
{
diff = SomeChildrenContraction(params_.commCost_, dagIn, expansionMap);
} break;
case sarkar_params::Mode::LEVEL_EVEN:
case sarkar_params::Mode::LEVEL_ODD:
{
diff = LevelContraction(params_.commCost_, dagIn, expansionMap);
} break;
case sarkar_params::Mode::FAN_IN_BUFFER:
case sarkar_params::Mode::FAN_OUT_BUFFER:
case sarkar_params::Mode::HOMOGENEOUS_BUFFER:
{
diff = HomogeneousBufferMerge(params_.commCost_, dagIn, expansionMap);
} break;
default:
{
diff = 0;
} break;
}
return expansionMap;
}
template <typename GraphTIn, typename GraphTOut>
std::vector<std::vector<VertexIdxT<GraphTIn>>>
Sarkar<GraphTIn, GraphTOut>::GenerateVertexExpansionMap(const GraphTIn &dagIn)
{
VertexIdxT<GraphTIn> dummy;
return GenerateVertexExpansionMap(dagIn, dummy);
}
template <typename GraphTIn, typename GraphTOut>
VertexIdxT<GraphTIn> Sarkar<GraphTIn, GraphTOut>::SomeChildrenContraction(
VWorkwT<GraphTIn> commCost, const GraphTIn &graph,
std::vector<std::vector<VertexIdxT<GraphTIn>>> &expansionMapOutput) const
{
using VertexType = VertexIdxT<GraphTIn>;
expansionMapOutput.clear();
const std::vector<VertexIdxT<GraphTIn>> vertexPoset = GetTopNodeDistance<GraphTIn, VertexIdxT<GraphTIn>>(graph);
const std::vector<VWorkwT<GraphTIn>> topDist = GetTopDistance(commCost, graph);
const std::vector<VWorkwT<GraphTIn>> botDist = GetBotDistance(commCost, graph);
auto cmp = [](const std::pair<long, std::vector<VertexType>> &lhs,
const std::pair<long, std::vector<VertexType>> &rhs)
{
return (lhs.first > rhs.first) || ((lhs.first == rhs.first) && (lhs.second < rhs.second));
};
std::set<std::pair<long, std::vector<VertexType>>, decltype(cmp)> vertPriority(cmp);
for (const VertexType &groupHead : graph.Vertices()) {
if (graph.OutDegree(groupHead) < 2) {
continue;
}
auto cmpChld = [&topDist, &botDist](const VertexType &lhs, const VertexType &rhs) {
return (topDist[lhs] < topDist[rhs]) || ((topDist[lhs] == topDist[rhs]) && (botDist[lhs] > botDist[rhs]))
|| ((topDist[lhs] == topDist[rhs]) && (botDist[lhs] == botDist[rhs]) && (lhs < rhs));
};
std::set<VertexType, decltype(cmpChld)> childrenPriority(cmpChld);
for (const VertexType &chld : graph.Children(groupHead)) {
if (vertexPoset[chld] == vertexPoset[groupHead] + 1) {
childrenPriority.emplace(chld);
}
}
if (childrenPriority.size() < 2) {
continue;
}
std::vector<std::pair<typename std::set<VertexType, decltype(cmpChld)>::const_iterator,
typename std::set<VertexType, decltype(cmpChld)>::const_iterator>>
admissbleChildrenGroups;
for (auto chldIterStart = childrenPriority.cbegin(); chldIterStart != childrenPriority.cend();) {
if (graph.VertexType(groupHead) != graph.VertexType(*chldIterStart)) {
++chldIterStart;
continue;
}
const VWorkwT<GraphTIn> tDist = topDist[*chldIterStart];
const VWorkwT<GraphTIn> bDist = botDist[*chldIterStart];
auto chldIterEnd = chldIterStart;
while (chldIterEnd != childrenPriority.cend()
&& tDist == topDist[*chldIterEnd] && bDist == botDist[*chldIterEnd]) {
if (graph.VertexType(groupHead) != graph.VertexType(*chldIterEnd)) {
break;
}
++chldIterEnd;
}
admissbleChildrenGroups.emplace_back(chldIterStart, chldIterEnd);
chldIterStart = chldIterEnd;
}
std::vector<VertexType> contractionEnsemble;
std::set<VertexType> contractionChildrenSet;
contractionEnsemble.reserve(1 + graph.OutDegree(groupHead));
contractionEnsemble.emplace_back(groupHead);
VWorkwT<GraphTIn> addedWeight = graph.VertexWorkWeight(groupHead);
for (std::size_t i = 0U; i < admissbleChildrenGroups.size(); ++i) {
const auto &first = admissbleChildrenGroups[i].first;
const auto &last = admissbleChildrenGroups[i].second;
for (auto it = first; it != last; ++it) {
contractionEnsemble.emplace_back(*it);
contractionChildrenSet.emplace(*it);
addedWeight += graph.VertexWorkWeight(*it);
}
if (addedWeight > params_.maxWeight_) break;
VWorkwT<GraphTIn> maxPath = 0;
for (const VertexType &vert : contractionEnsemble) {
maxPath = std::max(maxPath, botDist[vert] + topDist[vert] - graph.VertexWorkWeight(vert));
}
VWorkwT<GraphTIn> maxParentDist = 0;
VWorkwT<GraphTIn> maxChildDist = 0;
for (const VertexType &vert : contractionEnsemble) {
for (const VertexType &par : graph.Parents(vert)) {
if (par == groupHead) {
continue;
}
maxParentDist = std::max(maxParentDist, topDist[par] + commCost);
}
}
for (const VertexType &chld : graph.Children(groupHead)) {
if (contractionChildrenSet.find(chld) == contractionChildrenSet.end()) {
maxChildDist = std::max(maxChildDist, botDist[chld] + commCost);
}
}
for (std::size_t j = 1; j < contractionEnsemble.size(); j++) {
for (const VertexType &chld : graph.Children(contractionEnsemble[j])) {
maxChildDist = std::max(maxChildDist, botDist[chld] + commCost);
}
}
VWorkwT<GraphTIn> newMaxPath = maxChildDist + maxParentDist;
for (const VertexType &vert : contractionEnsemble) {
newMaxPath += graph.VertexWorkWeight(vert);
}
const long savings = static_cast<long>(maxPath) - static_cast<long>(newMaxPath);
if (savings + static_cast<long>(params_.leniency_ * static_cast<double>(maxPath)) >= 0) {
vertPriority.emplace(savings, contractionEnsemble);
}
}
}
std::vector<bool> partitionedFlag(graph.NumVertices(), false);
std::vector<bool> partitionedHeadFlag(graph.NumVertices(), false);
VertexIdxT<GraphTIn> maxCorseningNum
= graph.NumVertices()
- static_cast<VertexIdxT<GraphTIn>>(
static_cast<double>(graph.NumVertices()) * params_.geomDecay_);
VertexIdxT<GraphTIn> counter = 0;
long minSave = std::numeric_limits<long>::lowest();
for (auto prioIter = vertPriority.begin(); prioIter != vertPriority.end(); prioIter++) {
const long &vertSave = prioIter->first;
const VertexType &groupHead = prioIter->second.front();
const std::vector<VertexType> &contractionEnsemble = prioIter->second;
// Iterations halt
if (vertSave < minSave) {
break;
}
// Check whether we can glue
bool shouldSkip = false;
for (const VertexType &vert : contractionEnsemble) {
if (partitionedFlag[vert]) {
shouldSkip = true;
break;
}
}
if (shouldSkip) {
continue;
}
for (const VertexType &chld : graph.Children(groupHead)) {
if ((std::find(contractionEnsemble.cbegin(), contractionEnsemble.cend(), chld)
== contractionEnsemble.cend())
&& (vertexPoset[chld] == vertexPoset[groupHead] + 1)) {
if ((partitionedFlag[chld]) && (!partitionedHeadFlag[chld])) {
shouldSkip = true;
break;
}
}
}
if (shouldSkip) {
continue;
}
// Adding to partition
expansionMapOutput.emplace_back(contractionEnsemble);
counter += static_cast<VertexIdxT<GraphTIn>>(contractionEnsemble.size()) - 1;
if (counter > maxCorseningNum) {
minSave = vertSave;
}
partitionedHeadFlag[groupHead] = true;
for (const VertexType &vert : contractionEnsemble) {
partitionedFlag[vert] = true;
}
}
for (const VertexType &vert : graph.Vertices()) {
if (partitionedFlag[vert]) continue;
expansionMapOutput.emplace_back(std::initializer_list<VertexType>{vert});
}
return counter;
}
template <typename GraphTIn, typename GraphTOut>
VertexIdxT<GraphTIn> Sarkar<GraphTIn, GraphTOut>::SomeParentsContraction(
VWorkwT<GraphTIn> commCost, const GraphTIn &graph,
std::vector<std::vector<VertexIdxT<GraphTIn>>> &expansionMapOutput) const
{
using VertexType = VertexIdxT<GraphTIn>;
expansionMapOutput.clear();
const std::vector<VertexIdxT<GraphTIn>> vertexPoset = GetBotPosetMap(graph);
const std::vector<VWorkwT<GraphTIn>> topDist = GetTopDistance(commCost, graph);
const std::vector<VWorkwT<GraphTIn>> botDist = GetBotDistance(commCost, graph);
auto cmp = [](const std::pair<long, std::vector<VertexType>> &lhs,
const std::pair<long, std::vector<VertexType>> &rhs)
{
return (lhs.first > rhs.first) || ((lhs.first == rhs.first) && (lhs.second < rhs.second));
};
std::set<std::pair<long, std::vector<VertexType>>, decltype(cmp)> vertPriority(cmp);
for (const VertexType &groupFoot : graph.Vertices()) {
if (graph.InDegree(groupFoot) < 2) {
continue;
}
auto cmpPar = [&topDist, &botDist](const VertexType &lhs, const VertexType &rhs) {
return (botDist[lhs] < botDist[rhs]) || ((botDist[lhs] == botDist[rhs]) && (topDist[lhs] > topDist[rhs]))
|| ((botDist[lhs] == botDist[rhs]) && (topDist[lhs] == topDist[rhs]) && (lhs < rhs));
};
std::set<VertexType, decltype(cmpPar)> parentsPriority(cmpPar);
for (const VertexType &par : graph.Parents(groupFoot)) {
if (vertexPoset[par] + 1 == vertexPoset[groupFoot]) {
parentsPriority.emplace(par);
}
}
if (parentsPriority.size() < 2) {
continue;
}
std::vector<std::pair<typename std::set<VertexType, decltype(cmpPar)>::const_iterator,
typename std::set<VertexType, decltype(cmpPar)>::const_iterator>>
admissbleParentGroups;
for (auto parIterStart = parentsPriority.cbegin(); parIterStart != parentsPriority.cend();) {
if (graph.VertexType(groupFoot) != graph.VertexType(*parIterStart)) {
++parIterStart;
continue;
}
const VWorkwT<GraphTIn> tDist = topDist[*parIterStart];
const VWorkwT<GraphTIn> bDist = botDist[*parIterStart];
auto parIterEnd = parIterStart;
while (parIterEnd != parentsPriority.cend()
&& tDist == topDist[*parIterEnd] && bDist == botDist[*parIterEnd]) {
if (graph.VertexType(groupFoot) != graph.VertexType(*parIterEnd)) {
break;
}
++parIterEnd;
}
admissbleParentGroups.emplace_back(parIterStart, parIterEnd);
parIterStart = parIterEnd;
}
std::vector<VertexType> contractionEnsemble;
std::set<VertexType> contractionParentsSet;
contractionEnsemble.reserve(1 + graph.InDegree(groupFoot));
contractionEnsemble.emplace_back(groupFoot);
VWorkwT<GraphTIn> addedWeight = graph.VertexWorkWeight(groupFoot);
for (std::size_t i = 0U; i < admissbleParentGroups.size(); ++i) {
const auto &first = admissbleParentGroups[i].first;
const auto &last = admissbleParentGroups[i].second;
for (auto it = first; it != last; ++it) {
contractionEnsemble.emplace_back(*it);
contractionParentsSet.emplace(*it);
addedWeight += graph.VertexWorkWeight(*it);
}
if (addedWeight > params_.maxWeight_) break;
VWorkwT<GraphTIn> maxPath = 0;
for (const VertexType &vert : contractionEnsemble) {
maxPath = std::max(maxPath, topDist[vert] + botDist[vert] - graph.VertexWorkWeight(vert));
}
VWorkwT<GraphTIn> maxParentDist = 0;
VWorkwT<GraphTIn> maxChildDist = 0;
for (const VertexType &vert : contractionEnsemble) {
for (const VertexType &chld : graph.Children(vert)) {
if (chld == groupFoot) {
continue;
}
maxChildDist = std::max(maxChildDist, botDist[chld] + commCost);
}
}
for (const VertexType &par : graph.Parents(groupFoot)) {
if (contractionParentsSet.find(par) == contractionParentsSet.end()) {
maxParentDist = std::max(maxParentDist, topDist[par] + commCost);
}
}
for (std::size_t j = 1; j < contractionEnsemble.size(); j++) {
for (const VertexType &par : graph.Parents(contractionEnsemble[j])) {
maxParentDist = std::max(maxParentDist, topDist[par] + commCost);
}
}
VWorkwT<GraphTIn> newMaxPath = maxParentDist + maxChildDist;
for (const VertexType &vert : contractionEnsemble) {
newMaxPath += graph.VertexWorkWeight(vert);
}
long savings = static_cast<long>(maxPath) - static_cast<long>(newMaxPath);
if (savings + static_cast<long>(params_.leniency_ * static_cast<double>(maxPath)) >= 0) {
vertPriority.emplace(savings, contractionEnsemble);
}
}
}
std::vector<bool> partitionedFlag(graph.NumVertices(), false);
std::vector<bool> partitionedFootFlag(graph.NumVertices(), false);
VertexIdxT<GraphTIn> maxCorseningNum
= graph.NumVertices()
- static_cast<VertexIdxT<GraphTIn>>(
static_cast<double>(graph.NumVertices()) * params_.geomDecay_);
VertexIdxT<GraphTIn> counter = 0;
long minSave = std::numeric_limits<long>::lowest();
for (auto prioIter = vertPriority.begin(); prioIter != vertPriority.end(); prioIter++) {
const long &vertSave = prioIter->first;
const VertexType &groupFoot = prioIter->second.front();
const std::vector<VertexType> &contractionEnsemble = prioIter->second;
// Iterations halt
if (vertSave < minSave) break;
// Check whether we can glue
bool shouldSkip = std::any_of(contractionEnsemble.cbegin(),
contractionEnsemble.end(),
[&partitionedFlag](const auto &v) { return partitionedFlag[v]; });
if (shouldSkip) continue;
for (const VertexType &par : graph.Parents(groupFoot)) {
if ((std::find(contractionEnsemble.cbegin(), contractionEnsemble.cend(), par) == contractionEnsemble.cend())
&& (vertexPoset[par] + 1 == vertexPoset[groupFoot])) {
if ((partitionedFlag[par]) && (!partitionedFootFlag[par])) {
shouldSkip = true;
break;
}
}
}
if (shouldSkip) continue;
// Adding to partition
expansionMapOutput.emplace_back(contractionEnsemble);
counter += static_cast<VertexIdxT<GraphTIn>>(contractionEnsemble.size()) - 1;
if (counter > maxCorseningNum) {
minSave = vertSave;
}
partitionedFootFlag[groupFoot] = true;
for (const VertexType &vert : contractionEnsemble) {
partitionedFlag[vert] = true;
}
}
for (const VertexType &vert : graph.Vertices()) {
if (partitionedFlag[vert]) continue;
expansionMapOutput.emplace_back(std::initializer_list<VertexType>{vert});
}
return counter;
}
template <typename GraphTIn, typename GraphTOut>
VertexIdxT<GraphTIn> Sarkar<GraphTIn, GraphTOut>::LevelContraction(
VWorkwT<GraphTIn> commCost, const GraphTIn &graph,
std::vector<std::vector<VertexIdxT<GraphTIn>>> &expansionMapOutput) const
{
using VertexType = VertexIdxT<GraphTIn>;
expansionMapOutput.clear();
const std::vector<VertexIdxT<GraphTIn>> vertexPoset
= params_.useTopPoset_ ? GetTopNodeDistance<GraphTIn, VertexIdxT<GraphTIn>>(graph) : GetBotPosetMap(graph);
const std::vector<VWorkwT<GraphTIn>> topDist = GetTopDistance(commCost, graph);
const std::vector<VWorkwT<GraphTIn>> botDist = GetBotDistance(commCost, graph);
auto cmp = [](const std::pair<long, std::vector<VertexType>> &lhs,
const std::pair<long, std::vector<VertexType>> &rhs)
{
return (lhs.first > rhs.first) || ((lhs.first == rhs.first) && (lhs.second < rhs.second));
};
std::set<std::pair<long, std::vector<VertexType>>, decltype(cmp)> vertPriority(cmp);
const VertexIdxT<GraphTIn> minLevel = vertexPoset.size() == 0U
? 0U : *std::min_element(vertexPoset.cbegin(), vertexPoset.cend());
const VertexIdxT<GraphTIn> maxLevel = vertexPoset.size() == 0U
? 0U : *std::max_element(vertexPoset.cbegin(), vertexPoset.cend());
const VertexIdxT<GraphTIn> parity = params_.mode_ == sarkar_params::Mode::LEVEL_EVEN ? 0 : 1;
std::vector<std::vector<VertexIdxT<GraphTIn>>> levels(maxLevel - minLevel + 1);
for (const VertexType &vert : graph.Vertices()) {
levels[vertexPoset[vert] - minLevel].emplace_back(vert);
}
for (VertexIdxT<GraphTIn> headLevel = minLevel + parity; headLevel < maxLevel; headLevel += 2) {
const VertexIdxT<GraphTIn> footLevel = headLevel + 1;
const std::vector<VertexIdxT<GraphTIn>> &headVertices = levels[headLevel - minLevel];
const std::vector<VertexIdxT<GraphTIn>> &footVertices = levels[footLevel - minLevel];
UnionFindUniverse<VertexType, std::size_t, VWorkwT<GraphTIn>> uf;
for (const VertexType &vert : headVertices) {
uf.AddObject(vert, graph.VertexWorkWeight(vert));
}
for (const VertexType &vert : footVertices) {
uf.AddObject(vert, graph.VertexWorkWeight(vert));
}
for (const VertexType &srcVert : headVertices) {
for (const VertexType &tgtVert : graph.Children(srcVert)) {
if (vertexPoset[tgtVert] != footLevel) {
continue;
}
if (graph.VertexType(srcVert) != graph.VertexType(tgtVert)) {
continue;
}
uf.JoinByName(srcVert, tgtVert);
}
}
std::vector<std::vector<VertexType>> components = uf.GetConnectedComponents();
for (std::vector<VertexType> &comp : components) {
if (comp.size() < 2) {
continue;
}
if (uf.GetWeightOfComponentByName(comp.at(0)) > params_.maxWeight_) {
continue;
}
std::sort(comp.begin(), comp.end());
VWorkwT<GraphTIn> maxPath = std::numeric_limits<VWorkwT<GraphTIn>>::lowest();
for (const VertexType &vert : comp) {
maxPath = std::max(maxPath, topDist[vert] + botDist[vert] - graph.VertexWorkWeight(vert));
}
VWorkwT<GraphTIn> maxParentDist = 0;
for (const VertexType &vert : comp) {
for (const VertexType &par : graph.Parents(vert)) {
if (std::binary_search(comp.cbegin(), comp.cend(), par)) {
continue;
}
maxParentDist = std::max(maxParentDist, topDist[par] + commCost);
}
}
VWorkwT<GraphTIn> maxChildDist = 0;
for (const VertexType &vert : comp) {
for (const VertexType &chld : graph.Children(vert)) {
if (std::binary_search(comp.cbegin(), comp.cend(), chld)) {
continue;
}
maxChildDist = std::max(maxChildDist, botDist[chld] + commCost);
}
}
VWorkwT<GraphTIn> newMaxPath = maxParentDist + maxChildDist;
for (const VertexType &vert : comp) {
newMaxPath += graph.VertexWorkWeight(vert);
}
long savings = static_cast<long>(maxPath) - static_cast<long>(newMaxPath);
if (savings + static_cast<long>(params_.leniency_ * static_cast<double>(maxPath)) >= 0) {
vertPriority.emplace(savings, comp);
}
}
}
std::vector<bool> partitionedFlag(graph.NumVertices(), false);
VertexIdxT<GraphTIn> maxCorseningNum
= graph.NumVertices()
- static_cast<VertexIdxT<GraphTIn>>(
static_cast<double>(graph.NumVertices()) * params_.geomDecay_);
VertexIdxT<GraphTIn> counter = 0;
long minSave = std::numeric_limits<long>::lowest();
for (auto prioIter = vertPriority.cbegin(); prioIter != vertPriority.cend(); prioIter++) {
const long &compSave = prioIter->first;
const std::vector<VertexType> &comp = prioIter->second;
// Iterations halt
if (compSave < minSave) {
break;
}
// Check whether we can glue
bool shouldSkipHead = false;
bool shouldSkipFoot = false;
for (const VertexType &vert : comp) {
if (((vertexPoset[vert] - minLevel - parity) % 2) == 0) { // head vertex
for (const VertexType &chld : graph.Children(vert)) {
if ((vertexPoset[chld] == vertexPoset[vert] + 1) && partitionedFlag[chld]) {
shouldSkipHead = true;
}
}
} else { // foot vertex
for (const VertexType &par : graph.Parents(vert)) {
if ((vertexPoset[par] + 1 == vertexPoset[vert]) && partitionedFlag[par]) {
shouldSkipFoot = true;
}
}
}
}
if (shouldSkipHead && shouldSkipFoot) {
continue;
}
// Adding to partition
expansionMapOutput.emplace_back(comp);
counter += static_cast<VertexIdxT<GraphTIn>>(comp.size() - 1);
if (counter > maxCorseningNum) {
minSave = compSave;
}
for (const VertexType &vert : comp) {
partitionedFlag[vert] = true;
}
}
expansionMapOutput.reserve(graph.NumVertices() - counter);
for (const VertexType &vert : graph.Vertices()) {
if (partitionedFlag[vert]) continue;
expansionMapOutput.emplace_back(std::initializer_list<VertexType>{vert});
}
return counter;
}
template <typename GraphTIn, typename GraphTOut>
std::vector<std::size_t> Sarkar<GraphTIn, GraphTOut>::ComputeNodeHashes(
const GraphTIn &graph,
const std::vector<VertexIdxT<GraphTIn>> &vertexPoset,
const std::vector<VWorkwT<GraphTIn>> &dist) const
{
using VertexType = VertexIdxT<GraphTIn>;
std::vector<std::size_t> hashes(graph.NumVertices());
for (const VertexType &vert : graph.Vertices()) {
std::size_t &hash = hashes[vert];
hash = std::hash<VWorkwT<GraphTIn>>{}(graph.VertexWorkWeight(vert));
HashCombine(hash, vertexPoset[vert]);
HashCombine(hash, dist[vert]);
HashCombine(hash, graph.VertexType(vert));
}
return hashes;
}
template <typename GraphTIn, typename GraphTOut>
std::vector<std::size_t> Sarkar<GraphTIn, GraphTOut>::HomogeneousMerge(const std::size_t number,
const std::size_t minSize,
const std::size_t maxSize) const
{
std::size_t bestDiv = 1U;
const std::size_t minSizeAtLeastOne = minSize > 1U ? minSize : 1U;
const std::size_t maxSizeAtLeastOne = maxSize > 1U ? maxSize : 1U;
for (const std::size_t div : DivisorsList(number)) {
if (div > maxSizeAtLeastOne) {
continue;
}
if (div < minSizeAtLeastOne && bestDiv < div) {
bestDiv = div;
}
if (div >= minSizeAtLeastOne && ((bestDiv < minSizeAtLeastOne) || (div < bestDiv))) {
bestDiv = div;
}
}
if (bestDiv != 1U) {
return std::vector<std::size_t>(number / bestDiv, bestDiv);
}
std::size_t bestScore = 0U;
std::size_t bestBins = number / minSizeAtLeastOne;
std::size_t bins = (number / maxSizeAtLeastOne) > 2U ? (number / maxSizeAtLeastOne) : 2U;
for (; bins <= number / minSizeAtLeastOne; ++bins) {
if (number % bins == 0U && number != bins) {
return std::vector<std::size_t>(bins, number / bins);
}
std::size_t score = std::min(DivisorsList(number / bins).size(), DivisorsList((number / bins) + 1).size());
if (score >= bestScore) {
bestScore = score;
bestBins = bins;
}
}
std::size_t remainder = number % bestBins;
std::size_t size = number / bestBins;
std::vector<std::size_t> groups;
for (std::size_t i = 0U; i < bestBins; ++i) {
if (remainder != 0U) {
groups.emplace_back(size + 1U);
--remainder;
} else {
groups.emplace_back(size);
}
}
return groups;
}
template <typename GraphTIn, typename GraphTOut>
VertexIdxT<GraphTIn> Sarkar<GraphTIn, GraphTOut>::HomogeneousBufferMerge(
VWorkwT<GraphTIn> commCost, const GraphTIn &graph,
std::vector<std::vector<VertexIdxT<GraphTIn>>> &expansionMapOutput) const
{
using VertexType = VertexIdxT<GraphTIn>;
expansionMapOutput.clear();
const std::vector<VertexIdxT<GraphTIn>> vertexTopPoset = GetTopNodeDistance<GraphTIn, VertexIdxT<GraphTIn>>(graph);
const std::vector<VertexIdxT<GraphTIn>> vertexBotPoset = GetBotPosetMap(graph);
const std::vector<VWorkwT<GraphTIn>> topDist = GetTopDistance(commCost, graph);
const std::vector<VWorkwT<GraphTIn>> botDist = GetBotDistance(commCost, graph);
std::vector<std::size_t> hashValuesCombined(graph.NumVertices(), 1729U);
if (params_.mode_ == sarkar_params::Mode::FAN_OUT_BUFFER
|| params_.mode_ == sarkar_params::Mode::HOMOGENEOUS_BUFFER)
{
const std::vector<std::size_t> hashValues = ComputeNodeHashes(graph, vertexTopPoset, topDist);
std::vector<std::size_t> hashValuesWithParents = hashValues;
for (const VertexType &par : graph.Vertices()) {
for (const VertexType &chld : graph.Children(par)) {
HashCombine(hashValuesWithParents[chld], hashValues[par]);
}
}
for (const VertexType &vert : graph.Vertices()) {
HashCombine(hashValuesCombined[vert], hashValuesWithParents[vert]);
}
}
if (params_.mode_ == sarkar_params::Mode::FAN_IN_BUFFER
|| params_.mode_ == sarkar_params::Mode::HOMOGENEOUS_BUFFER)
{
const std::vector<std::size_t> hashValues = ComputeNodeHashes(graph, vertexBotPoset, botDist);
std::vector<std::size_t> hashValuesWithChildren = hashValues;
for (const VertexType &chld : graph.Vertices()) {
for (const VertexType &par : graph.Parents(chld)) {
HashCombine(hashValuesWithChildren[par], hashValues[chld]);
}
}
for (const VertexType &vert : graph.Vertices()) {
HashCombine(hashValuesCombined[vert], hashValuesWithChildren[vert]);
}
}
std::unordered_map<std::size_t, std::set<VertexType>> orbits;
for (const VertexType &vert : graph.Vertices()) {
if (graph.VertexWorkWeight(vert) > params_.smallWeightThreshold_) {
continue;
}
const std::size_t hash = hashValuesCombined[vert];
auto foundIter = orbits.find(hash);
if (foundIter == orbits.end()) {
orbits.emplace(std::piecewise_construct,
std::forward_as_tuple(hash),
std::forward_as_tuple(std::initializer_list<VertexIdxT<GraphTIn>>{vert}));
} else {
foundIter->second.emplace(vert);
}
}
VertexIdxT<GraphTIn> counter = 0;
std::vector<bool> partitionedFlag(graph.NumVertices(), false);
for (const VertexType &vert : graph.Vertices()) {
if (graph.VertexWorkWeight(vert) > params_.smallWeightThreshold_) {
continue;
}
if (partitionedFlag[vert]) {
continue;
}
const std::set<VertexType> &orb = orbits.at(hashValuesCombined[vert]);
if (orb.size() <= 1U) {
continue;
}
std::set<VertexType> parents;
if (params_.mode_ == sarkar_params::Mode::FAN_OUT_BUFFER
|| params_.mode_ == sarkar_params::Mode::HOMOGENEOUS_BUFFER)
{
for (const VertexType &par : graph.Parents(vert)) {
parents.emplace(par);
}
}
std::set<VertexType> children;
if (params_.mode_ == sarkar_params::Mode::FAN_IN_BUFFER
|| params_.mode_ == sarkar_params::Mode::HOMOGENEOUS_BUFFER)
{
for (const VertexType &chld : graph.Children(vert)) {
children.emplace(chld);
}
}
std::set<VertexType> secureOrb;
for (const VertexType &vertCandidate : orb) {
if (vertexTopPoset[vertCandidate] != vertexTopPoset[vert]) {
continue;
}
if (vertexBotPoset[vertCandidate] != vertexBotPoset[vert]) {
continue;
}
if (graph.VertexWorkWeight(vertCandidate) != graph.VertexWorkWeight(vert)) {
continue;
}
if (topDist[vertCandidate] != topDist[vert]) {
continue;
}
if (botDist[vertCandidate] != botDist[vert]) {
continue;
}
if (graph.VertexType(vertCandidate) != graph.VertexType(vert)) {
continue;
}
if (params_.mode_ == sarkar_params::Mode::FAN_OUT_BUFFER
|| params_.mode_ == sarkar_params::Mode::HOMOGENEOUS_BUFFER)
{
std::set<VertexType> candidateParents;
for (const VertexType &par : graph.Parents(vertCandidate)) {
candidateParents.emplace(par);
}
if (candidateParents != parents) {
continue;
}
}
if (params_.mode_ == sarkar_params::Mode::FAN_IN_BUFFER
|| params_.mode_ == sarkar_params::Mode::HOMOGENEOUS_BUFFER)
{
std::set<VertexType> candidateChildren;
for (const VertexType &chld : graph.Children(vertCandidate)) {
candidateChildren.emplace(chld);
}
if (candidateChildren != children) {
continue;
}
}
secureOrb.emplace(vertCandidate);
}
if (secureOrb.size() <= 1U) {
continue;
}
const VWorkwT<GraphTIn> desiredVerticesInGroup = graph.VertexWorkWeight(vert) == 0
? std::numeric_limits<VWorkwT<GraphTIn>>::lowest()
: params_.smallWeightThreshold_
/ graph.VertexWorkWeight(vert);
const VWorkwT<GraphTIn> maxVerticesInGroup = graph.VertexWorkWeight(vert) == 0
? std::numeric_limits<VWorkwT<GraphTIn>>::max()
: params_.maxWeight_ / graph.VertexWorkWeight(vert);
const std::size_t minDesiredSize
= desiredVerticesInGroup < 2 ? 2U : static_cast<std::size_t>(desiredVerticesInGroup);
const std::size_t maxDesiredSize
= std::max(minDesiredSize, std::min(minDesiredSize * 2U, static_cast<std::size_t>(maxVerticesInGroup)));
std::vector<std::size_t> groups = HomogeneousMerge(secureOrb.size(), minDesiredSize, maxDesiredSize);
auto secureOrbIter = secureOrb.begin();
for (std::size_t groupSize : groups) {
std::vector<VertexType> cluster;
for (std::size_t i = 0; i < groupSize; ++i) {
cluster.emplace_back(*secureOrbIter);
++secureOrbIter;
}
expansionMapOutput.emplace_back(std::move(cluster));
counter += static_cast<VertexType>(groupSize) - 1;
}
for (const VertexType &touchedVertex : secureOrb) {
partitionedFlag[touchedVertex] = true;
}
}
for (const VertexType &vert : graph.Vertices()) {
if (partitionedFlag[vert]) continue;
expansionMapOutput.emplace_back(std::initializer_list<VertexType>{vert});
}
return counter;
}
} // end namespace osp
} // namespace npu::tile_fwk
#endif // PASS_OSP_SARKAR_TPP