#include "polly/FlattenAlgo.h"
#include "polly/Support/ISLOStream.h"
#include "polly/Support/ISLTools.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "polly-flatten-algo"
using namespace polly;
using namespace llvm;
namespace {
/// Check whether a single dimension of @p Set is bounded once all parameters
/// have been eliminated.
///
/// Every parameter dimension and every set dimension other than @p dim is
/// projected out; the result is whether the remaining one-dimensional set is
/// bounded.
///
/// @param Set The set to inspect.
/// @param dim Index of the set dimension of interest.
///
/// @return true if the isolated dimension is bounded.
bool isDimBoundedByConstant(isl::set Set, unsigned dim) {
  // Remove the parameters first, then every set dimension in front of dim.
  const auto NumParams = unsignedFromIslSize(Set.dim(isl::dim::param));
  isl::set Tail = Set.project_out(isl::dim::param, 0, NumParams)
                      .project_out(isl::dim::set, 0, dim);

  // The dimension of interest is now at position 0; drop whatever follows it.
  const auto TailDims = unsignedFromIslSize(Tail.tuple_dim());
  assert(TailDims >= 1);
  isl::set Isolated = Tail.project_out(isl::dim::set, 1, TailDims - 1);

  return bool(Isolated.is_bounded());
}
/// Check whether a single dimension of @p Set is bounded while keeping the
/// set's parameters.
///
/// All set dimensions other than @p dim are projected out and the remaining
/// one-dimensional set is tested for boundedness. Parameters are left in
/// place (presumably so that bounds may be expressed in terms of them).
///
/// @param Set The set to inspect.
/// @param dim Index of the set dimension of interest.
///
/// @return true if the isolated dimension is bounded.
bool isDimBoundedByParameter(isl::set Set, unsigned dim) {
  // Drop the leading dimensions so that dim ends up at position 0.
  isl::set Tail = Set.project_out(isl::dim::set, 0, dim);

  const auto TailDims = unsignedFromIslSize(Tail.tuple_dim());
  assert(TailDims >= 1);

  // Drop everything behind the dimension of interest.
  isl::set Isolated = Tail.project_out(isl::dim::set, 1, TailDims - 1);
  return bool(Isolated.is_bounded());
}
/// Whether the first output dimension of @p BMap is not plainly fixed.
///
/// @return true if plain_get_val_if_fixed() yields a null value or NaN for
///         output dimension 0; false otherwise.
bool isVariableDim(const isl::basic_map &BMap) {
  const isl::val Fixed = BMap.plain_get_val_if_fixed(isl::dim::out, 0);
  // A null result is checked first so that is_nan() is only asked of a valid
  // value.
  if (Fixed.is_null())
    return true;
  return bool(Fixed.is_nan());
}
/// Scan the basic maps of @p Map with the isl::basic_map overload.
///
/// @return false as soon as one basic map reports a variable first output
///         dimension; true if none does (including when there are no basic
///         maps).
///
/// NOTE(review): the returned values read inverted relative to the function
/// name. Callers rely on the current behaviour; confirm before changing it.
bool isVariableDim(const isl::map &Map) {
  const auto Pieces = Map.get_basic_map_list();
  for (isl::basic_map Piece : Pieces) {
    const bool PieceIsVariable = isVariableDim(Piece);
    if (PieceIsVariable)
      return false;
  }
  return true;
}
/// Scan the maps of @p UMap with the isl::map overload.
///
/// @return false as soon as the isl::map overload returns true for one of the
///         maps; true if it returns false for all of them (including when
///         @p UMap contains no maps).
///
/// NOTE(review): together with the isl::map overload this forms a double
/// negation; confirm the intended meaning before changing either one.
bool isVariableDim(const isl::union_map &UMap) {
  const auto Maps = UMap.get_map_list();
  for (isl::map Single : Maps) {
    const bool MapResult = isVariableDim(Single);
    if (MapResult)
      return false;
  }
  return true;
}
/// Compute @p UPwAff - @p Val, one piecewise expression at a time.
///
/// @param UPwAff Union of piecewise affine expressions to subtract from.
/// @param Val    Constant subtracted from every expression.
///
/// @return @p UPwAff itself when @p Val is zero; otherwise the union of all
///         expressions with @p Val subtracted. A default-constructed
///         isl::union_pw_aff is returned if the iteration reports an error.
isl::union_pw_aff subtract(isl::union_pw_aff UPwAff, isl::val Val) {
  // Subtracting zero changes nothing.
  if (Val.is_zero())
    return UPwAff;

  isl::union_pw_aff Difference = isl::union_pw_aff::empty(UPwAff.get_space());

  auto SubtractFromPiece = [&Difference,
                            &Val](isl::pw_aff Piece) -> isl::stat {
    // Constant expression Val over the whole domain space of this piece.
    isl::set WholeDomain = isl::set::universe(Piece.get_space().domain());
    isl::pw_aff Constant(WholeDomain, Val);
    Difference = Difference.union_add(isl::union_pw_aff(Piece.sub(Constant)));
    return isl::stat::ok();
  };

  if (UPwAff.foreach_pw_aff(SubtractFromPiece).is_error())
    return {};
  return Difference;
}
/// Compute @p UPwAff * @p Val, one piecewise expression at a time.
///
/// @param UPwAff Union of piecewise affine expressions to scale.
/// @param Val    Constant factor applied to every expression.
///
/// @return @p UPwAff itself when @p Val is one; otherwise the union of all
///         expressions multiplied by @p Val. A default-constructed
///         isl::union_pw_aff is returned if the iteration reports an error.
isl::union_pw_aff multiply(isl::union_pw_aff UPwAff, isl::val Val) {
  // Multiplying by one changes nothing.
  if (Val.is_one())
    return UPwAff;

  isl::union_pw_aff Product = isl::union_pw_aff::empty(UPwAff.get_space());

  auto ScalePiece = [&Product, &Val](isl::pw_aff Piece) -> isl::stat {
    // Constant expression Val over the whole domain space of this piece.
    isl::set WholeDomain = isl::set::universe(Piece.get_space().domain());
    isl::pw_aff Factor(WholeDomain, Val);
    Product = Product.union_add(isl::union_pw_aff(Piece.mul(Factor)));
    return isl::stat::ok();
  };

  if (UPwAff.foreach_pw_aff(ScalePiece).is_error())
    return {};
  return Product;
}
/// Remove @p n dimensions, starting at position @p first, from the range of
/// every map in @p UMap.
///
/// @param UMap  The schedule whose output dimensions are removed.
/// @param first Index of the first output dimension to remove.
/// @param n     Number of output dimensions to remove.
///
/// @return @p UMap itself when @p n is zero; otherwise the union of all maps
///         of @p UMap with the requested output dimensions projected out.
isl::union_map scheduleProjectOut(const isl::union_map &UMap, unsigned first,
                                  unsigned n) {
  // Nothing to remove: hand the input back untouched. Skipping the projection
  // should have no effect on schedule ranges.
  if (n == 0)
    return UMap;

  auto Result = isl::union_map::empty(UMap.ctx());
  for (isl::map Map : UMap.get_map_list()) {
    auto Outprojected = Map.project_out(isl::dim::out, first, n);
    Result = Result.unite(Outprojected);
  }
  return Result;
}
/// Extract output dimension @p pos of @p UMap as an isl::union_pw_aff.
///
/// Each map is reduced to the single output dimension @p pos; the reduced
/// maps are united and converted via isl::union_pw_multi_aff and
/// isl::multi_union_pw_aff, of which element 0 is returned.
///
/// @param UMap The schedule to extract from; every map must have more than
///             @p pos output dimensions (asserted).
/// @param pos  Index of the output dimension to extract.
isl::union_pw_aff scheduleExtractDimAff(isl::union_map UMap, unsigned pos) {
  isl::union_map Isolated = isl::union_map::empty(UMap.ctx());

  for (isl::map Piece : UMap.get_map_list()) {
    const unsigned NumOut = unsignedFromIslSize(Piece.range_tuple_dim());
    assert(NumOut > pos);

    // Drop the dimensions in front of pos, then those behind it.
    const unsigned NumTrailing = NumOut - pos - 1;
    isl::map OneDim = Piece.project_out(isl::dim::out, 0, pos);
    OneDim = OneDim.project_out(isl::dim::out, 1, NumTrailing);

    Isolated = Isolated.unite(OneDim);
  }

  auto AsMultiAff = isl::union_pw_multi_aff(Isolated);
  return isl::multi_union_pw_aff(AsMultiAff).at(0);
}
/// Try to flatten @p Schedule by treating its first scatter dimension as a
/// sequence of steps.
///
/// The loop repeatedly takes the lexicographically smallest remaining value
/// of the first scatter dimension, recursively flattens the part of the
/// schedule that belongs to it (with that first dimension removed), and
/// places the result behind everything emitted so far by adding a running
/// counter to its first dimension.
///
/// @param Schedule The schedule to flatten; its scatter space must have at
///                 least two dimensions (asserted).
///
/// @return The rewritten schedule, or a default-constructed isl::union_map if
///         the first dimension is not bounded by a constant or if one of the
///         steps is not bounded.
isl::union_map tryFlattenSequence(isl::union_map Schedule) {
auto IslCtx = Schedule.ctx();
auto ScatterSet = isl::set(Schedule.range());
auto ParamSpace = Schedule.get_space().params();
auto Dims = unsignedFromIslSize(ScatterSet.tuple_dim());
assert(Dims >= 2u);
// The loop below removes one value of the first dimension per iteration;
// presumably it would not terminate if that dimension were unbounded.
if (!isDimBoundedByConstant(ScatterSet, 0)) {
LLVM_DEBUG(dbgs() << "Abort; dimension is not of fixed size\n");
return {};
}
auto AllDomains = Schedule.domain();
// Built from the union of all domains; used below to pull parameter-only
// expressions back onto the statement domains.
auto AllDomainsToNull = isl::union_pw_multi_aff(AllDomains);
auto NewSchedule = isl::union_map::empty(ParamSpace.ctx());
// Running offset for the next step; advanced by PartLen at the end of each
// iteration. (Presumably starts out as the zero expression -- confirm against
// the isl::pw_aff(isl::local_space) constructor.)
auto Counter = isl::pw_aff(isl::local_space(ParamSpace.set_from_params()));
while (!ScatterSet.is_empty()) {
LLVM_DEBUG(dbgs() << "Next counter:\n " << Counter << "\n");
LLVM_DEBUG(dbgs() << "Remaining scatter set:\n " << ScatterSet << "\n");
// Smallest remaining value of the first dimension, widened again to the full
// scatter dimensionality with unconstrained trailing dimensions.
auto ThisSet = ScatterSet.project_out(isl::dim::set, 1, Dims - 1);
auto ThisFirst = ThisSet.lexmin();
auto ScatterFirst = ThisFirst.add_dims(isl::dim::set, Dims - 1);
// Part of the schedule that belongs to this step, without its first
// dimension, flattened recursively.
auto SubSchedule = Schedule.intersect_range(ScatterFirst);
SubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
SubSchedule = flattenSchedule(SubSchedule);
unsigned SubDims = getNumScatterDims(SubSchedule);
assert(SubDims >= 1);
// Split the flattened sub-schedule into its first dimension and the rest.
auto FirstSubSchedule = scheduleProjectOut(SubSchedule, 1, SubDims - 1);
auto FirstScheduleAff = scheduleExtractDimAff(FirstSubSchedule, 0);
auto RemainingSubSchedule = scheduleProjectOut(SubSchedule, 0, 1);
auto FirstSubScatter = isl::set(FirstSubSchedule.range());
LLVM_DEBUG(dbgs() << "Next step in sequence is:\n " << FirstSubScatter
<< "\n");
// The extent of this step is needed to advance the counter, so it has to be
// bounded.
if (!isDimBoundedByParameter(FirstSubScatter, 0)) {
LLVM_DEBUG(dbgs() << "Abort; sequence step is not bounded\n");
return {};
}
auto FirstSubScatterMap = isl::map::from_range(FirstSubScatter);
// Extent of this step: PartLen = PartMax - PartMin + 1.
auto PartMin = FirstSubScatterMap.dim_min(0);
auto PartMax = FirstSubScatterMap.dim_max(0);
auto One = isl::pw_aff(isl::set::universe(ParamSpace.set_from_params()),
isl::val::one(IslCtx));
auto PartLen = PartMax.add(PartMin.neg()).add(One);
// Subtract the step's minimum from its first dimension, then add the counter.
auto AllPartMin = isl::union_pw_aff(PartMin).pullback(AllDomainsToNull);
auto FirstScheduleAffNormalized = FirstScheduleAff.sub(AllPartMin);
auto AllCounter = isl::union_pw_aff(Counter).pullback(AllDomainsToNull);
auto FirstScheduleAffWithOffset =
FirstScheduleAffNormalized.add(AllCounter);
// New first dimension followed by the remaining dimensions of the flattened
// sub-schedule.
auto ScheduleWithOffset =
isl::union_map::from(
isl::union_pw_multi_aff(FirstScheduleAffWithOffset))
.flat_range_product(RemainingSubSchedule);
NewSchedule = NewSchedule.unite(ScheduleWithOffset);
// This step is done: remove it from the work list and advance the counter.
ScatterSet = ScatterSet.subtract(ScatterFirst);
Counter = Counter.add(PartLen);
}
LLVM_DEBUG(dbgs() << "Sequence-flatten result is:\n " << NewSchedule
<< "\n");
return NewSchedule;
}
/// Try to flatten @p Schedule by folding its first scatter dimension into the
/// first dimension of its recursively flattened remainder.
///
/// The first dimension of the flattened remainder must have constant bounds
/// Min and Max. The new leading dimension is
///   (inner - Min) + outer * (Max - Min + 1)
/// and is followed by the remaining dimensions of the flattened remainder.
///
/// @param Schedule The schedule to flatten; it must have at least two scatter
///                 dimensions (asserted).
///
/// @return The rewritten schedule, or a default-constructed isl::union_map if
///         the inner dimension is not bounded by a constant or its bounds
///         cannot be determined.
isl::union_map tryFlattenLoop(isl::union_map Schedule) {
  assert(getNumScatterDims(Schedule) >= 2);

  // Flatten everything below the first dimension.
  auto FlatInner = flattenSchedule(scheduleProjectOut(Schedule, 0, 1));
  unsigned FlatInnerDims = getNumScatterDims(FlatInner);
  assert(FlatInnerDims >= 1);

  // Extent of the first flattened dimension, without parameters.
  auto InnerExtent = isl::set(FlatInner.range());
  auto NumParams = unsignedFromIslSize(InnerExtent.dim(isl::dim::param));
  InnerExtent = InnerExtent.project_out(isl::dim::param, 0, NumParams);
  InnerExtent = InnerExtent.project_out(isl::dim::set, 1, FlatInnerDims - 1);

  if (!isDimBoundedByConstant(InnerExtent, 0)) {
    LLVM_DEBUG(dbgs() << "Abort; dimension not bounded by constant\n");
    return {};
  }

  auto LowerBound = InnerExtent.dim_min(0);
  LLVM_DEBUG(dbgs() << "Min bound:\n " << LowerBound << "\n");
  auto LowerVal = getConstant(LowerBound, false, true);

  auto UpperBound = InnerExtent.dim_max(0);
  LLVM_DEBUG(dbgs() << "Max bound:\n " << UpperBound << "\n");
  auto UpperVal = getConstant(UpperBound, true, false);

  const bool BoundMissing = LowerVal.is_null() || UpperVal.is_null();
  if (BoundMissing || LowerVal.is_nan() || UpperVal.is_nan()) {
    LLVM_DEBUG(dbgs() << "Abort; dimension bounds could not be determined\n");
    return {};
  }

  // Split the flattened remainder into its first dimension and the rest.
  auto InnerAff = scheduleExtractDimAff(FlatInner, 0);
  auto TrailingSchedule = scheduleProjectOut(std::move(FlatInner), 0, 1);

  // Number of values the inner dimension spans, and the inner dimension
  // shifted by its minimum.
  auto Span = UpperVal.sub(LowerVal).add(1);
  auto InnerAffNormalized = subtract(InnerAff, LowerVal);

  // Outer dimension scaled by the inner span.
  auto OuterAff = scheduleExtractDimAff(Schedule, 0);
  auto ScaledOuter = multiply(OuterAff, Span);

  isl::union_pw_multi_aff Combined = InnerAffNormalized.add(ScaledOuter);
  auto CombinedMap = isl::union_map::from(Combined);

  auto Flattened = CombinedMap.flat_range_product(TrailingSchedule);
  LLVM_DEBUG(dbgs() << "Loop-flatten result is:\n " << Flattened << "\n");
  return Flattened;
}
}
/// Recursively flatten @p Schedule.
///
/// Schedules with at most one scatter dimension are returned unchanged.
/// Otherwise the following strategies are tried in order and the first one
/// that yields a non-null result wins:
///   1. sequence flattening, only if isVariableDim() is false for the
///      schedule;
///   2. loop flattening;
///   3. sequence flattening (again, unconditionally).
/// If none succeeds, @p Schedule is returned as is.
isl::union_map polly::flattenSchedule(isl::union_map Schedule) {
  const unsigned NumDims = getNumScatterDims(Schedule);
  LLVM_DEBUG(dbgs() << "Recursive schedule to process:\n " << Schedule
                    << "\n");

  // Base cases: nothing left to flatten.
  if (NumDims <= 1)
    return Schedule;

  if (!isVariableDim(Schedule)) {
    LLVM_DEBUG(dbgs() << "Fixed dimension; try sequence flattening\n");
    isl::union_map AsSequence = tryFlattenSequence(Schedule);
    if (!AsSequence.is_null())
      return AsSequence;
  }

  LLVM_DEBUG(dbgs() << "Try loop flattening\n");
  isl::union_map AsLoop = tryFlattenLoop(Schedule);
  if (!AsLoop.is_null())
    return AsLoop;

  LLVM_DEBUG(dbgs() << "Try sequence flattening again\n");
  isl::union_map AsSequenceRetry = tryFlattenSequence(Schedule);
  if (!AsSequenceRetry.is_null())
    return AsSequenceRetry;

  // Nothing worked; leave the schedule as it is.
  return Schedule;
}