#include "mlir/ExecutionEngine/SparseTensorUtils.h"
#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
namespace {
// Size in bytes of the line buffer used when reading tensor files; fgets
// stores at most kColWidth - 1 characters plus the terminating NUL.
static constexpr int kColWidth{1025};
/// Returns `lhs * rhs`. In assertion-enabled builds, aborts if the unsigned
/// 64-bit product would wrap around. With NDEBUG the product wraps silently.
static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
  constexpr uint64_t kMax = std::numeric_limits<uint64_t>::max();
  // A zero factor can never overflow; otherwise rhs must fit in kMax / lhs.
  const bool fits = (lhs == 0) || (rhs <= kMax / lhs);
  assert(fits && "Integer overflow");
  (void)fits;
  return lhs * rhs;
}
// Prints a printf-style message, prefixed with "SparseTensorUtils: ", to
// stderr and terminates the process with exit status 1. The first macro
// argument must be a string literal because it is concatenated with the
// prefix literal at compile time. The do/while(0) wrapper makes the macro
// behave as a single statement, for example inside an unbraced if/else.
#define FATAL(...) \
do { \
fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__); \
exit(1); \
} while (0)
/// Debug-only check that the permuted dimension sizes `dimSizes` agree with
/// the requested `shape`. For each r < rank, `shape[r]` must either be 0 (a
/// size that matches anything) or equal `dimSizes[perm[r]]`. Compiles to
/// nothing observable under NDEBUG.
static inline void
assertPermutedSizesMatchShape(const std::vector<uint64_t> &dimSizes,
                              uint64_t rank, const uint64_t *perm,
                              const uint64_t *shape) {
  assert(perm && shape);
  assert(rank == dimSizes.size() && "Rank mismatch");
  uint64_t d = 0;
  while (d < rank) {
    assert((shape[d] == 0 || shape[d] == dimSizes[perm[d]]) &&
           "Dimension size mismatch");
    ++d;
  }
}
/// One stored entry of a COO tensor: a non-owning pointer to the entry's
/// coordinates, which live in a buffer owned by the enclosing container,
/// together with the entry's value.
template <typename V>
struct Element final {
  Element(uint64_t *ind, V val) : indices{ind}, value{val} {}
  uint64_t *indices;
  V value;
};
// Callback type for visiting tensor elements. It receives the element's
// coordinates and its value. Declared as a const reference to std::function,
// so callers can pass a lambda directly; a temporary std::function is
// materialized at the call site.
template <typename V>
using ElementConsumer =
const std::function<void(const std::vector<uint64_t> &, V)> &;
/// An in-memory coordinate-scheme (COO) sparse tensor: a list of elements,
/// each pairing a coordinate tuple with a value. All coordinates are stored
/// contiguously in the `indices` buffer, and every Element points into that
/// buffer.
///
/// NOTE: the implicitly generated copy operations would leave a copy's
/// `Element::indices` pointing into the *source* object's buffer. Handle
/// instances through pointers, as the factory below does, rather than copying
/// them.
template <typename V>
struct SparseTensorCOO final {
public:
  /// Creates an empty COO tensor with the given dimension sizes. When
  /// `capacity` is nonzero, reserves room for that many elements.
  SparseTensorCOO(const std::vector<uint64_t> &dimSizes, uint64_t capacity)
      : dimSizes(dimSizes) {
    if (capacity) {
      elements.reserve(capacity);
      // capacity * rank can wrap for large inputs; checkedMul asserts on it.
      indices.reserve(checkedMul(capacity, getRank()));
    }
  }

  /// Appends an element with coordinates `ind` and value `val`. Must not be
  /// called while an iteration started by startIterator() is in progress.
  void add(const std::vector<uint64_t> &ind, V val) {
    assert(!iteratorLocked && "Attempt to add() after startIterator()");
    uint64_t *base = indices.data();
    uint64_t size = indices.size();
    uint64_t rank = getRank();
    assert(ind.size() == rank && "Element rank mismatch");
    for (uint64_t r = 0; r < rank; r++) {
      assert(ind[r] < dimSizes[r] && "Index is too large for the dimension");
      indices.push_back(ind[r]);
    }
    // push_back may have reallocated `indices`. If so, rebase every existing
    // element's pointer from the old buffer onto the new one.
    uint64_t *newBase = indices.data();
    if (newBase != base) {
      for (uint64_t i = 0, n = elements.size(); i < n; i++)
        elements[i].indices = newBase + (elements[i].indices - base);
      base = newBase;
    }
    elements.emplace_back(base + size, val);
  }

  /// Sorts the elements lexicographically by their coordinates. Only the
  /// Element records move; the coordinate buffer itself is left untouched.
  void sort() {
    assert(!iteratorLocked && "Attempt to sort() after startIterator()");
    uint64_t rank = getRank();
    std::sort(elements.begin(), elements.end(),
              [rank](const Element<V> &e1, const Element<V> &e2) {
                for (uint64_t r = 0; r < rank; r++) {
                  if (e1.indices[r] == e2.indices[r])
                    continue;
                  return e1.indices[r] < e2.indices[r];
                }
                return false;
              });
  }

  uint64_t getRank() const { return dimSizes.size(); }
  const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
  const std::vector<Element<V>> &getElements() const { return elements; }

  /// Begins an iteration over the elements. Mutation is locked until getNext()
  /// returns nullptr.
  void startIterator() {
    iteratorLocked = true;
    iteratorPos = 0;
  }

  /// Returns the next element, or nullptr (and unlocks) once exhausted.
  const Element<V> *getNext() {
    assert(iteratorLocked && "Attempt to getNext() before startIterator()");
    if (iteratorPos < elements.size())
      return &(elements[iteratorPos++]);
    iteratorLocked = false;
    return nullptr;
  }

  /// Allocates an empty COO tensor whose sizes are `dimSizes` permuted by
  /// `perm` (size of dimension r goes to slot perm[r]). The caller owns the
  /// result.
  static SparseTensorCOO<V> *newSparseTensorCOO(uint64_t rank,
                                                const uint64_t *dimSizes,
                                                const uint64_t *perm,
                                                uint64_t capacity = 0) {
    std::vector<uint64_t> permsz(rank);
    for (uint64_t r = 0; r < rank; r++) {
      assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
      permsz[perm[r]] = dimSizes[r];
    }
    return new SparseTensorCOO<V>(permsz, capacity);
  }

private:
  const std::vector<uint64_t> dimSizes; // Sizes in COO (permuted) order.
  std::vector<Element<V>> elements;     // Elements, pointing into `indices`.
  std::vector<uint64_t> indices;        // Flat coordinate storage.
  bool iteratorLocked = false;
  // 64-bit to match elements.size(); an `unsigned` cursor would wrap after
  // 2^32 elements and make getNext() loop forever.
  uint64_t iteratorPos = 0;
};
// Forward declaration: SparseTensorStorageBase::newEnumerator hands out
// enumerators of this type.
template <typename V>
class SparseTensorEnumeratorBase;
// Aborts with a "<P,I,V> type mismatch" message. Used as the default body of
// the type-dispatched virtual methods below.
// NOTE(review): `#NAME` stringifies the whole argument, including its quote
// characters, so the printed name appears with embedded quotes. The message
// also lacks a trailing newline.
#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " #NAME);
// Abstract base for sparse tensor storage, independent of the pointer (P),
// index (I) and value (V) types. Holds:
// - the dimension sizes in storage order (newSparseTensor passes sizes already
//   permuted by `perm`);
// - the per-dimension level types;
// - `rev`, the inverse of the construction permutation.
class SparseTensorStorageBase {
public:
// Constructor. Requires rank > 0, every size > 0, and every level to be dense
// or compressed; these are checked by assert only. Builds rev[perm[r]] = r.
// NOTE(review): `sparsity` is dereferenced in the member-init list before the
// null check below runs.
SparseTensorStorageBase(const std::vector<uint64_t> &dimSizes,
const uint64_t *perm, const DimLevelType *sparsity)
: dimSizes(dimSizes), rev(getRank()),
dimTypes(sparsity, sparsity + getRank()) {
assert(perm && sparsity);
const uint64_t rank = getRank();
assert(rank > 0 && "Trivial shape is unsupported");
for (uint64_t r = 0; r < rank; r++) {
assert(dimSizes[r] > 0 && "Dimension size zero has trivial storage");
assert((dimTypes[r] == DimLevelType::kDense ||
dimTypes[r] == DimLevelType::kCompressed) &&
"Unsupported DimLevelType");
}
for (uint64_t r = 0; r < rank; r++)
rev[perm[r]] = r;
}
virtual ~SparseTensorStorageBase() = default;
uint64_t getRank() const { return dimSizes.size(); }
const std::vector<uint64_t> &getDimSizes() const { return dimSizes; }
uint64_t getDimSize(uint64_t d) const {
assert(d < getRank());
return dimSizes[d];
}
// Inverse permutation: rev[perm[r]] == r.
const std::vector<uint64_t> &getRev() const { return rev; }
const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
bool isCompressedDim(uint64_t d) const {
assert(d < getRank());
return (dimTypes[d] == DimLevelType::kCompressed);
}
// Each DECL_* macro below declares one virtual overload per type listed by
// FOREVERY_V / FOREVERY_FIXED_O (defined outside this chunk). Every default
// body aborts via FATAL_PIV; SparseTensorStorage<P, I, V> overrides the
// overloads that match its own types.
#define DECL_NEWENUMERATOR(VNAME, V) \
virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t, \
const uint64_t *) const { \
FATAL_PIV("newEnumerator" #VNAME); \
}
FOREVERY_V(DECL_NEWENUMERATOR)
#undef DECL_NEWENUMERATOR
#define DECL_GETPOINTERS(PNAME, P) \
virtual void getPointers(std::vector<P> **, uint64_t) { \
FATAL_PIV("getPointers" #PNAME); \
}
FOREVERY_FIXED_O(DECL_GETPOINTERS)
#undef DECL_GETPOINTERS
#define DECL_GETINDICES(INAME, I) \
virtual void getIndices(std::vector<I> **, uint64_t) { \
FATAL_PIV("getIndices" #INAME); \
}
FOREVERY_FIXED_O(DECL_GETINDICES)
#undef DECL_GETINDICES
#define DECL_GETVALUES(VNAME, V) \
virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
FOREVERY_V(DECL_GETVALUES)
#undef DECL_GETVALUES
#define DECL_LEXINSERT(VNAME, V) \
virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
FOREVERY_V(DECL_LEXINSERT)
#undef DECL_LEXINSERT
#define DECL_EXPINSERT(VNAME, V) \
virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) { \
FATAL_PIV("expInsert" #VNAME); \
}
FOREVERY_V(DECL_EXPINSERT)
#undef DECL_EXPINSERT
// Finalizes the storage after a sequence of lexInsert/expInsert calls.
virtual void endInsert() = 0;
protected:
// Copy-construction is available to subclasses only; assignment is disabled.
SparseTensorStorageBase(const SparseTensorStorageBase &) = default;
SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
private:
const std::vector<uint64_t> dimSizes;
std::vector<uint64_t> rev;
const std::vector<DimLevelType> dimTypes;
};
#undef FATAL_PIV
// Forward declaration; defined after SparseTensorStorage, which befriends it.
template <typename P, typename I, typename V>
class SparseTensorEnumerator;
// Concrete sparse tensor storage with pointer type P, index type I and value
// type V.
// - For each compressed dimension d, `pointers[d]` and `indices[d]` hold the
//   segment boundaries and stored coordinates. Both stay empty for dense
//   dimensions.
// - All stored values live in `values`.
// - `idx` records the coordinates of the most recent insertion, so the next
//   lexInsert can tell which segments are finished.
template <typename P, typename I, typename V>
class SparseTensorStorage final : public SparseTensorStorageBase {
// Delegated-to constructor: creates one (empty) pointers/indices vector per
// dimension and sizes the insertion cursor `idx`.
SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
const uint64_t *perm, const DimLevelType *sparsity)
: SparseTensorStorageBase(dimSizes, perm, sparsity), pointers(getRank()),
indices(getRank()), idx(getRank()) {}
public:
// Builds storage from an optional COO tensor.
// - The loop seeds each compressed dimension's `pointers` with a leading 0 and
//   reserves space. `sz` is the product of the dense extents since the
//   previous compressed dimension.
// - With a COO, its sizes must equal ours. It is sorted in place (hence the
//   non-const pointer) and packed via fromCOO.
// - Without a COO, an all-dense tensor gets a zero-filled `values`; other
//   layouts start empty and are filled with lexInsert/expInsert + endInsert.
SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
const uint64_t *perm, const DimLevelType *sparsity,
SparseTensorCOO<V> *coo)
: SparseTensorStorage(dimSizes, perm, sparsity) {
bool allDense = true;
uint64_t sz = 1;
for (uint64_t r = 0, rank = getRank(); r < rank; r++) {
if (isCompressedDim(r)) {
pointers[r].reserve(sz + 1);
pointers[r].push_back(0);
indices[r].reserve(sz);
sz = 1;
allDense = false;
} else {
sz = checkedMul(sz, getDimSizes()[r]);
}
}
if (coo) {
assert(coo->getDimSizes() == getDimSizes() && "Tensor size mismatch");
coo->sort();
const std::vector<Element<V>> &elements = coo->getElements();
uint64_t nnz = elements.size();
values.reserve(nnz);
fromCOO(elements, 0, nnz, 0);
} else if (allDense) {
values.resize(sz, 0);
}
}
// Converting constructor from any other storage; defined out of line.
SparseTensorStorage(const std::vector<uint64_t> &dimSizes,
const uint64_t *perm, const DimLevelType *sparsity,
const SparseTensorStorageBase &tensor);
~SparseTensorStorage() final = default;
// Exposes the internal pointer array for dimension d (not a copy).
void getPointers(std::vector<P> **out, uint64_t d) final {
assert(d < getRank());
*out = &pointers[d];
}
// Exposes the internal index array for dimension d (not a copy).
void getIndices(std::vector<I> **out, uint64_t d) final {
assert(d < getRank());
*out = &indices[d];
}
// Exposes the internal values array (not a copy).
void getValues(std::vector<V> **out) final { *out = &values; }
// Inserts `val` at coordinates `cursor`. Calls must arrive in strictly
// increasing lexicographic order (checked by assert in lexDiff).
// - Segments below the first differing dimension are closed by endPath.
// - The new path is appended by insPath; at that dimension, filling resumes
//   just after the previous coordinate.
void lexInsert(const uint64_t *cursor, V val) final {
uint64_t diff = 0;
uint64_t top = 0;
if (!values.empty()) {
diff = lexDiff(cursor);
endPath(diff + 1);
top = idx[diff] + 1;
}
insPath(cursor, diff, top, val);
}
// Inserts `count` entries that share cursor[0..rank-1) and differ only in the
// last dimension.
// - `added[0..count)` lists the last-dimension coordinates to insert and is
//   sorted in place.
// - `values[i]` / `filled[i]` hold the value and flag for coordinate i. Each
//   inserted slot is reset to 0 / false afterwards.
// - The `values` parameter shadows the member of the same name inside this
//   method.
void expInsert(uint64_t *cursor, V *values, bool *filled, uint64_t *added,
uint64_t count) final {
if (count == 0)
return;
std::sort(added, added + count);
const uint64_t lastDim = getRank() - 1;
uint64_t index = added[0];
cursor[lastDim] = index;
lexInsert(cursor, values[index]);
assert(filled[index]);
values[index] = 0;
filled[index] = false;
for (uint64_t i = 1; i < count; i++) {
assert(index < added[i] && "non-lexicographic insertion");
index = added[i];
cursor[lastDim] = index;
// Only the last dimension changes, so append directly at lastDim; the
// slots between the previous coordinate and this one are filled first.
insPath(cursor, lastDim, added[i - 1] + 1, values[index]);
assert(filled[index]);
values[index] = 0;
filled[index] = false;
}
}
// Closes all open segments after the final insertion. With no insertions at
// all, finalizes from the outermost dimension instead.
void endInsert() final {
if (values.empty())
finalizeSegment(0);
else
endPath(0);
}
// Allocates an enumerator over this tensor. The caller owns `*out`.
void newEnumerator(SparseTensorEnumeratorBase<V> **out, uint64_t rank,
const uint64_t *perm) const final {
*out = new SparseTensorEnumerator<P, I, V>(*this, rank, perm);
}
// Returns a newly allocated COO holding every stored value, with coordinates
// reordered by `perm`. The caller owns the result.
SparseTensorCOO<V> *toCOO(const uint64_t *perm) const {
SparseTensorEnumeratorBase<V> *enumerator;
newEnumerator(&enumerator, getRank(), perm);
SparseTensorCOO<V> *coo =
new SparseTensorCOO<V>(enumerator->permutedSizes(), values.size());
enumerator->forallElements([&coo](const std::vector<uint64_t> &ind, V val) {
coo->add(ind, val);
});
assert(coo->getElements().size() == values.size());
delete enumerator;
return coo;
}
// Factory from an optional COO; the caller keeps ownership of `coo`.
// - With a COO, uses its (already permuted) sizes, which must agree with
//   `shape`; a 0 in `shape` matches any size.
// - Without a COO, uses `shape` permuted by `perm`.
// The result is heap-allocated and owned by the caller.
static SparseTensorStorage<P, I, V> *
newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
const DimLevelType *sparsity, SparseTensorCOO<V> *coo) {
SparseTensorStorage<P, I, V> *n = nullptr;
if (coo) {
const auto &coosz = coo->getDimSizes();
assertPermutedSizesMatchShape(coosz, rank, perm, shape);
n = new SparseTensorStorage<P, I, V>(coosz, perm, sparsity, coo);
} else {
std::vector<uint64_t> permsz(rank);
for (uint64_t r = 0; r < rank; r++) {
assert(shape[r] > 0 && "Dimension size zero has trivial storage");
permsz[perm[r]] = shape[r];
}
n = new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, coo);
}
return n;
}
// Factory that converts another storage `source` into this P/I/V layout. A
// temporary enumerator is used only to obtain the permuted sizes. The result
// is heap-allocated and owned by the caller.
static SparseTensorStorage<P, I, V> *
newSparseTensor(uint64_t rank, const uint64_t *shape, const uint64_t *perm,
const DimLevelType *sparsity,
const SparseTensorStorageBase *source) {
assert(source && "Got nullptr for source");
SparseTensorEnumeratorBase<V> *enumerator;
source->newEnumerator(&enumerator, rank, perm);
const auto &permsz = enumerator->permutedSizes();
assertPermutedSizesMatchShape(permsz, rank, perm, shape);
auto *tensor =
new SparseTensorStorage<P, I, V>(permsz, perm, sparsity, *source);
delete enumerator;
return tensor;
}
private:
// Appends `count` copies of `pos` to pointers[d]; `pos` must fit in P.
void appendPointer(uint64_t d, uint64_t pos, uint64_t count = 1) {
assert(isCompressedDim(d));
assert(pos <= std::numeric_limits<P>::max() &&
"Pointer value is too large for the P-type");
pointers[d].insert(pointers[d].end(), count, static_cast<P>(pos));
}
// Records coordinate `i` at dimension d.
// - Compressed: pushes `i` onto indices[d].
// - Dense: pads the slots [full, i) that precede `i` with zeros (innermost
//   dimension) or with empty sub-segments (otherwise).
void appendIndex(uint64_t d, uint64_t full, uint64_t i) {
if (isCompressedDim(d)) {
assert(i <= std::numeric_limits<I>::max() &&
"Index value is too large for the I-type");
indices[d].push_back(static_cast<I>(i));
} else {
assert(i >= full && "Index was already filled");
if (i == full)
return;
if (d + 1 == getRank())
values.insert(values.end(), i - full, 0);
else
finalizeSegment(d + 1, 0, i - full);
}
}
// Overwrites an already-allocated slot indices[d][pos] with `i`.
void writeIndex(uint64_t d, uint64_t pos, uint64_t i) {
assert(isCompressedDim(d));
assert(pos < indices[d].size() && "Index position is out of bounds");
assert(i <= std::numeric_limits<I>::max() &&
"Index value is too large for the I-type");
indices[d][pos] = static_cast<I>(i);
}
// Number of entries at level d, given `parentSz` entries at level d-1.
// Compressed: read from the final pointer. Dense: parentSz times the size.
uint64_t assembledSize(uint64_t parentSz, uint64_t d) const {
if (isCompressedDim(d))
return pointers[d][parentSz];
return parentSz * getDimSizes()[d];
}
// Recursively packs the sorted elements [lo, hi) into levels d..rank.
// - All elements in the range agree on coordinates 0..d-1.
// - Each run that shares coordinate d becomes one child.
// - At d == rank, only the first element's value is stored, so duplicate
//   coordinates are silently dropped.
void fromCOO(const std::vector<Element<V>> &elements, uint64_t lo,
uint64_t hi, uint64_t d) {
uint64_t rank = getRank();
assert(d <= rank && hi <= elements.size());
if (d == rank) {
assert(lo < hi);
values.push_back(elements[lo].value);
return;
}
uint64_t full = 0;
while (lo < hi) {
uint64_t i = elements[lo].indices[d];
uint64_t seg = lo + 1;
while (seg < hi && elements[seg].indices[d] == i)
seg++;
appendIndex(d, full, i);
full = i + 1;
fromCOO(elements, lo, seg, d + 1);
lo = seg;
}
finalizeSegment(d, full);
}
// Closes `count` segments at level d.
// - Compressed: appends the current index count as the segment end.
// - Dense: pads the remaining (size - full) slots of each segment, recursing
//   into deeper levels.
void finalizeSegment(uint64_t d, uint64_t full = 0, uint64_t count = 1) {
if (count == 0)
return;
if (isCompressedDim(d)) {
appendPointer(d, indices[d].size(), count);
} else {
const uint64_t sz = getDimSizes()[d];
assert(sz >= full && "Segment is overfull");
count = checkedMul(count, sz - full);
if (d + 1 == getRank())
values.insert(values.end(), count, 0);
else
finalizeSegment(d + 1, 0, count);
}
}
// Finalizes the open segments of dimensions rank-1 down to `diff`, innermost
// first. Each is treated as filled up to its last inserted coordinate.
void endPath(uint64_t diff) {
uint64_t rank = getRank();
assert(diff <= rank);
for (uint64_t i = 0; i < rank - diff; i++) {
const uint64_t d = rank - i - 1;
finalizeSegment(d, idx[d] + 1);
}
}
// Appends the coordinates cursor[diff..rank) and then `val`, updating `idx`.
// At dimension `diff`, dense padding starts at `top`; deeper levels start at
// 0.
void insPath(const uint64_t *cursor, uint64_t diff, uint64_t top, V val) {
uint64_t rank = getRank();
assert(diff < rank);
for (uint64_t d = diff; d < rank; d++) {
uint64_t i = cursor[d];
appendIndex(d, top, i);
top = 0;
idx[d] = i;
}
values.push_back(val);
}
// Returns the first dimension where `cursor` exceeds the previous insertion.
// NOTE(review): a duplicate or out-of-order insertion is caught only by
// assert. Under NDEBUG this returns `-1u` (0xFFFFFFFF, not UINT64_MAX), and
// the caller then indexes out of bounds.
uint64_t lexDiff(const uint64_t *cursor) const {
for (uint64_t r = 0, rank = getRank(); r < rank; r++)
if (cursor[r] > idx[r])
return r;
else
assert(cursor[r] == idx[r] && "non-lexicographic insertion");
assert(0 && "duplication insertion");
return -1u;
}
friend class SparseTensorEnumerator<P, I, V>;
std::vector<std::vector<P>> pointers;
std::vector<std::vector<I>> indices;
std::vector<V> values;
std::vector<uint64_t> idx;
};
/// Abstract visitor over the elements physically stored in a
/// SparseTensorStorageBase. Coordinates are delivered in the order selected by
/// the constructor's `perm`, not in the tensor's storage order.
template <typename V>
class SparseTensorEnumeratorBase {
public:
  /// For each storage dimension s, precomputes the output coordinate slot
  /// reord[s] = perm[rev[s]] and records that dimension's size in permsz.
  SparseTensorEnumeratorBase(const SparseTensorStorageBase &tensor,
                             uint64_t rank, const uint64_t *perm)
      : src(tensor), permsz(src.getRev().size()), reord(getRank()),
        cursor(getRank()) {
    assert(perm && "Received nullptr for permutation");
    assert(rank == getRank() && "Permutation rank mismatch");
    const std::vector<uint64_t> &revMap = src.getRev();
    const std::vector<uint64_t> &storageSizes = src.getDimSizes();
    uint64_t s = 0;
    while (s < rank) {
      const uint64_t slot = perm[revMap[s]];
      reord[s] = slot;
      permsz[slot] = storageSizes[s];
      ++s;
    }
  }

  virtual ~SparseTensorEnumeratorBase() = default;

  // Copying is disabled.
  SparseTensorEnumeratorBase(const SparseTensorEnumeratorBase &) = delete;
  SparseTensorEnumeratorBase &
  operator=(const SparseTensorEnumeratorBase &) = delete;

  /// Rank of the enumerated tensor.
  uint64_t getRank() const { return permsz.size(); }

  /// Dimension sizes in output (permuted) order.
  const std::vector<uint64_t> &permutedSizes() const { return permsz; }

  /// Calls `yield` with the permuted coordinates and value of each stored
  /// element.
  virtual void forallElements(ElementConsumer<V> yield) = 0;

protected:
  const SparseTensorStorageBase &src; // Tensor being enumerated.
  std::vector<uint64_t> permsz;       // Output-order dimension sizes.
  std::vector<uint64_t> reord;        // Storage dim -> output coordinate slot.
  std::vector<uint64_t> cursor;       // Coordinates of the current element.
};
/// Enumerator specialized to SparseTensorStorage<P, I, V>. It walks the
/// storage levels recursively, reading the storage's private arrays directly
/// (it is a friend of the storage class).
template <typename P, typename I, typename V>
class SparseTensorEnumerator final : public SparseTensorEnumeratorBase<V> {
  using Base = SparseTensorEnumeratorBase<V>;

public:
  SparseTensorEnumerator(const SparseTensorStorage<P, I, V> &tensor,
                         uint64_t rank, const uint64_t *perm)
      : Base(tensor, rank, perm) {}

  ~SparseTensorEnumerator() final = default;

  void forallElements(ElementConsumer<V> yield) final {
    forallElements(yield, 0, 0);
  }

private:
  /// Visits the subtree at level `d`, where `parentPos` is the position of the
  /// enclosing entry at level d-1 (or the value slot once d == rank).
  void forallElements(ElementConsumer<V> yield, uint64_t parentPos,
                      uint64_t d) {
    const auto &storage =
        static_cast<const SparseTensorStorage<P, I, V> &>(this->src);

    // All levels fixed: parentPos now indexes the values array.
    if (d == Base::getRank()) {
      assert(parentPos < storage.values.size() &&
             "Value position is out of bounds");
      yield(this->cursor, storage.values[parentPos]);
      return;
    }

    // Output coordinate slot written at this level.
    uint64_t &slot = this->cursor[this->reord[d]];

    // Dense level: every coordinate 0..size-1 is present.
    if (!storage.isCompressedDim(d)) {
      const uint64_t extent = storage.getDimSizes()[d];
      const uint64_t first = parentPos * extent;
      for (uint64_t i = 0; i < extent; ++i) {
        slot = i;
        forallElements(yield, first + i, d + 1);
      }
      return;
    }

    // Compressed level: visit the entries in the pointer-delimited segment.
    const std::vector<P> &ptrs = storage.pointers[d];
    assert(parentPos + 1 < ptrs.size() &&
           "Parent pointer position is out of bounds");
    const uint64_t lo = static_cast<uint64_t>(ptrs[parentPos]);
    const uint64_t hi = static_cast<uint64_t>(ptrs[parentPos + 1]);
    const std::vector<I> &crd = storage.indices[d];
    assert(hi <= crd.size() && "Index position is out of bounds");
    for (uint64_t pos = lo; pos < hi; ++pos) {
      slot = static_cast<uint64_t>(crd[pos]);
      forallElements(yield, pos, d + 1);
    }
  }
};
class SparseTensorNNZ final {
public:
SparseTensorNNZ(const std::vector<uint64_t> &dimSizes,
const std::vector<DimLevelType> &sparsity)
: dimSizes(dimSizes), dimTypes(sparsity), nnz(getRank()) {
assert(dimSizes.size() == dimTypes.size() && "Rank mismatch");
bool uncompressed = true;
(void)uncompressed;
uint64_t sz = 1;
for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
switch (dimTypes[r]) {
case DimLevelType::kCompressed:
assert(uncompressed &&
"Multiple compressed layers not currently supported");
uncompressed = false;
nnz[r].resize(sz, 0);
break;
case DimLevelType::kDense:
assert(uncompressed &&
"Dense after compressed not currently supported");
break;
case DimLevelType::kSingleton:
break;
}
sz = checkedMul(sz, dimSizes[r]);
}
}
SparseTensorNNZ(const SparseTensorNNZ &) = delete;
SparseTensorNNZ &operator=(const SparseTensorNNZ &) = delete;
uint64_t getRank() const { return dimSizes.size(); }
template <typename V>
void initialize(SparseTensorEnumeratorBase<V> &enumerator) {
assert(enumerator.getRank() == getRank() && "Tensor rank mismatch");
assert(enumerator.permutedSizes() == dimSizes && "Tensor size mismatch");
enumerator.forallElements(
[this](const std::vector<uint64_t> &ind, V) { add(ind); });
}
using NNZConsumer = const std::function<void(uint64_t)> &;
void forallIndices(uint64_t stopDim, NNZConsumer yield) const {
assert(stopDim < getRank() && "Stopping-dimension is out of bounds");
assert(dimTypes[stopDim] == DimLevelType::kCompressed &&
"Cannot look up non-compressed dimensions");
forallIndices(yield, stopDim, 0, 0);
}
private:
void add(const std::vector<uint64_t> &ind) {
uint64_t parentPos = 0;
for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
if (dimTypes[r] == DimLevelType::kCompressed)
nnz[r][parentPos]++;
parentPos = parentPos * dimSizes[r] + ind[r];
}
}
void forallIndices(NNZConsumer yield, uint64_t stopDim, uint64_t parentPos,
uint64_t d) const {
assert(d <= stopDim);
if (d == stopDim) {
assert(parentPos < nnz[d].size() && "Cursor is out of range");
yield(nnz[d][parentPos]);
} else {
const uint64_t sz = dimSizes[d];
const uint64_t pstart = parentPos * sz;
for (uint64_t i = 0; i < sz; i++)
forallIndices(yield, stopDim, pstart + i, d + 1);
}
}
const std::vector<uint64_t> &dimSizes;
const std::vector<DimLevelType> &dimTypes;
std::vector<std::vector<uint64_t>> nnz;
};
/// Converting constructor: builds this storage from any other storage
/// `tensor`, enumerated with permutation `perm`. It makes two passes over the
/// source elements and then fixes up the pointers.
///
/// 1. Counting (SparseTensorNNZ): each compressed dimension's `pointers` is
///    set to prefix sums of the segment sizes, and `indices` / `values` are
///    allocated to their exact final sizes.
/// 2. Scattering: each element is written into place. `pointers[r][parentPos]`
///    serves as a running write cursor and is incremented per element.
/// 3. Fix-up: after pass 2 every pointer has advanced to the start of the next
///    segment, so the pointers are shifted back by one slot and slot 0 is
///    reset to 0.
template <typename P, typename I, typename V>
SparseTensorStorage<P, I, V>::SparseTensorStorage(
    const std::vector<uint64_t> &dimSizes, const uint64_t *perm,
    const DimLevelType *sparsity, const SparseTensorStorageBase &tensor)
    : SparseTensorStorage(dimSizes, perm, sparsity) {
  SparseTensorEnumeratorBase<V> *enumerator;
  tensor.newEnumerator(&enumerator, getRank(), perm);
  {
    // Pass 1. Scoped so the counts are released before pass 2.
    SparseTensorNNZ nnz(getDimSizes(), getDimTypes());
    nnz.initialize(*enumerator);
    uint64_t parentSz = 1;
    for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
      if (isCompressedDim(r)) {
        pointers[r].reserve(parentSz + 1);
        pointers[r].push_back(0);
        uint64_t currentPos = 0;
        // The running total must be captured by reference so it accumulates
        // across callbacks.
        nnz.forallIndices(r, [this, &currentPos, r](uint64_t n) {
          currentPos += n;
          appendPointer(r, currentPos);
        });
        assert(pointers[r].size() == parentSz + 1 &&
               "Final pointers size doesn't match allocated size");
      }
      parentSz = assembledSize(parentSz, r);
      if (isCompressedDim(r))
        indices[r].resize(parentSz, 0);
    }
    values.resize(parentSz, 0);
  }
  // Pass 2: scatter each element, advancing the per-segment write cursor.
  enumerator->forallElements([this](const std::vector<uint64_t> &ind, V val) {
    uint64_t parentSz = 1, parentPos = 0;
    for (uint64_t rank = getRank(), r = 0; r < rank; r++) {
      if (isCompressedDim(r)) {
        assert(parentPos < parentSz && "Pointers position is out of bounds");
        const uint64_t currentPos = pointers[r][parentPos];
        pointers[r][parentPos]++;
        writeIndex(r, currentPos, ind[r]);
        parentPos = currentPos;
      } else {
        parentPos = parentPos * getDimSizes()[r] + ind[r];
      }
      parentSz = assembledSize(parentSz, r);
    }
    assert(parentPos < values.size() && "Value position is out of bounds");
    values[parentPos] = val;
  });
  delete enumerator;
  // Fix-up: shift each compressed dimension's pointers right by one slot.
  for (uint64_t parentSz = 1, rank = getRank(), r = 0; r < rank; r++) {
    if (isCompressedDim(r)) {
      assert(parentSz == pointers[r].size() - 1 &&
             "Actual pointers size doesn't match the expected size");
      assert(pointers[r][parentSz - 1] == pointers[r][parentSz] &&
             "Pointers got corrupted");
      for (uint64_t n = 0; n < parentSz; n++) {
        const uint64_t parentPos = parentSz - n;
        pointers[r][parentPos] = pointers[r][parentPos - 1];
      }
      pointers[r][0] = 0;
    }
    parentSz = assembledSize(parentSz, r);
  }
}
/// Lower-cases the NUL-terminated string `token` in place and returns it.
/// Each byte is passed to `tolower` as an `unsigned char`. Passing a negative
/// `char` (a non-ASCII byte on platforms where `char` is signed) is undefined
/// behaviour.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
  return token;
}
/// Reader for sparse tensors stored in MatrixMarket Exchange (`.mtx`) or
/// extended FROSTT (`.tns`) files. The format is chosen from the file name.
///
/// Typical use: openFile(), readHeader(), then one readLine() per stored
/// entry. After readHeader(), `idata` holds the rank in idata[0], the number
/// of entries in idata[1], and the dimension sizes from idata[2] onwards.
class SparseTensorFile final {
public:
  enum class ValueKind {
    kInvalid = 0,
    kPattern = 1,
    kReal = 2,
    kInteger = 3,
    kComplex = 4,
    kUndefined = 5
  };

  /// Binds the reader to `filename` without opening it. The string is not
  /// copied, so it must outlive this object. Accepts `const char *`, so both
  /// mutable and read-only names can be passed.
  explicit SparseTensorFile(const char *filename) : filename(filename) {
    assert(filename && "Received nullptr for filename");
  }

  SparseTensorFile(const SparseTensorFile &) = delete;
  SparseTensorFile &operator=(const SparseTensorFile &) = delete;

  /// Closes the file if it is still open.
  ~SparseTensorFile() { closeFile(); }

  /// Opens the file for reading. Aborts if it is already open or cannot be
  /// opened.
  void openFile() {
    if (file)
      FATAL("Already opened file %s\n", filename);
    file = fopen(filename, "r");
    if (!file)
      FATAL("Cannot find file %s\n", filename);
  }

  /// Closes the file; a no-op if it is not open.
  void closeFile() {
    if (file) {
      fclose(file);
      file = nullptr;
    }
  }

  /// Reads the next line, at most kColWidth - 1 characters, into the internal
  /// buffer and returns it. Aborts on EOF or a read error.
  char *readLine() {
    assert(file && "Attempt to readLine() before openFile()");
    if (fgets(line, kColWidth, file))
      return line;
    FATAL("Cannot read next line of %s\n", filename);
  }

  /// Parses the header according to the file extension (.mtx or .tns).
  void readHeader() {
    assert(file && "Attempt to readHeader() before openFile()");
    if (strstr(filename, ".mtx"))
      readMMEHeader();
    else if (strstr(filename, ".tns"))
      readExtFROSTTHeader();
    else
      FATAL("Unknown format %s\n", filename);
    assert(isValid() && "Failed to read the header");
  }

  ValueKind getValueKind() const { return valueKind_; }

  /// True once a header has been read successfully.
  bool isValid() const { return valueKind_ != ValueKind::kInvalid; }

  bool isPattern() const {
    assert(isValid() && "Attempt to isPattern() before readHeader()");
    return valueKind_ == ValueKind::kPattern;
  }

  bool isSymmetric() const {
    assert(isValid() && "Attempt to isSymmetric() before readHeader()");
    return isSymmetric_;
  }

  uint64_t getRank() const {
    assert(isValid() && "Attempt to getRank() before readHeader()");
    return idata[0];
  }

  uint64_t getNNZ() const {
    assert(isValid() && "Attempt to getNNZ() before readHeader()");
    return idata[1];
  }

  /// Pointer to the getRank() dimension sizes read from the header.
  const uint64_t *getDimSizes() const {
    assert(isValid() && "Attempt to getDimSizes() before readHeader()");
    return idata + 2;
  }

  uint64_t getDimSize(uint64_t d) const {
    assert(d < getRank());
    return idata[2 + d];
  }

  /// Debug-only check that the file's rank and sizes agree with `shape`. A 0
  /// in `shape` matches any size.
  void assertMatchesShape(uint64_t rank, const uint64_t *shape) const {
    assert(rank == getRank() && "Rank mismatch");
    for (uint64_t r = 0; r < rank; r++)
      assert((shape[r] == 0 || shape[r] == idata[2 + r]) &&
             "Dimension size mismatch");
  }

private:
  void readMMEHeader();
  void readExtFROSTTHeader();

  const char *filename;
  FILE *file = nullptr;
  ValueKind valueKind_ = ValueKind::kInvalid;
  bool isSymmetric_ = false;
  uint64_t idata[512]; // [rank, nnz, size_0, size_1, ...]
  char line[kColWidth];
};
/// Parses a MatrixMarket Exchange (.mtx) header. The header consists of:
/// - the banner "%%MatrixMarket matrix coordinate <field> <symmetry>",
///   matched case-insensitively;
/// - any following '%' comment lines;
/// - the "rows cols nnz" size line.
/// Records the value kind and symmetry, and stores rank 2, the two sizes and
/// the entry count in `idata`. Aborts via FATAL on malformed or unsupported
/// headers.
void SparseTensorFile::readMMEHeader() {
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5)
    FATAL("Corrupt header in %s\n", filename);

  // Normalize every banner token to lower case once; toLower works in place.
  char *const tokens[] = {header, object, format, field, symmetry};
  for (char *token : tokens)
    toLower(token);

  if (strcmp(field, "pattern") == 0) {
    valueKind_ = ValueKind::kPattern;
  } else if (strcmp(field, "real") == 0) {
    valueKind_ = ValueKind::kReal;
  } else if (strcmp(field, "integer") == 0) {
    valueKind_ = ValueKind::kInteger;
  } else if (strcmp(field, "complex") == 0) {
    valueKind_ = ValueKind::kComplex;
  } else {
    FATAL("Unexpected header field value in %s\n", filename);
  }

  isSymmetric_ = strcmp(symmetry, "symmetric") == 0;
  const bool isGeneral = strcmp(symmetry, "general") == 0;
  const bool bannerOk = strcmp(header, "%%matrixmarket") == 0 &&
                        strcmp(object, "matrix") == 0 &&
                        strcmp(format, "coordinate") == 0;
  if (!bannerOk || !(isGeneral || isSymmetric_))
    FATAL("Cannot find a general sparse matrix in %s\n", filename);

  // Skip '%' comment lines; the first other line holds the sizes.
  do {
    readLine();
  } while (line[0] == '%');

  idata[0] = 2; // A MatrixMarket matrix is always rank 2.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3)
    FATAL("Cannot find size in %s\n", filename);
}
/// Parses an extended FROSTT (.tns) header:
/// - any leading '#' comment lines;
/// - a "rank nnz" line;
/// - `rank` dimension sizes, which are stored at idata[2..].
/// The value kind is left as kUndefined because the format does not declare
/// it. Aborts via FATAL on malformed input or on a rank too large for the
/// fixed `idata` buffer.
void SparseTensorFile::readExtFROSTTHeader() {
  while (true) {
    readLine();
    if (line[0] != '#')
      break;
  }
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
    FATAL("Cannot find metadata in %s\n", filename);
  // The rank comes from the (untrusted) file, and the sizes are written from
  // idata[2] onwards. Reject ranks that would overrun the fixed buffer instead
  // of writing past its end.
  constexpr uint64_t kMaxRank = sizeof(idata) / sizeof(idata[0]) - 2;
  if (idata[0] > kMaxRank)
    FATAL("Rank %" PRIu64 " exceeds the supported maximum %" PRIu64 " in %s\n",
          idata[0], kMaxRank, filename);
  for (uint64_t r = 0; r < idata[0]; r++)
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
      FATAL("Cannot find dimension size %s\n", filename);
  readLine(); // Consume the remainder of the sizes line.
  valueKind_ = ValueKind::kUndefined;
}
/// Adds `value` at coordinates `indices` to `coo`. When `is_symmetric_value`
/// is set, also adds the mirrored entry (indices[1], indices[0]); this is
/// used for symmetric matrices, so `indices` must then have exactly two
/// entries. `indices` is taken by const reference to avoid a vector copy per
/// stored entry.
template <typename T, typename V>
static inline void addValue(T *coo, V value,
                            const std::vector<uint64_t> &indices,
                            bool is_symmetric_value) {
  coo->add(indices, value);
  if (is_symmetric_value) {
    assert(indices.size() == 2 && "Symmetric values require a matrix");
    coo->add({indices[1], indices[0]}, value);
  }
}
/// Complex overload: parses a real part and an imaginary part from
/// `*linePtr`, advancing it, and adds the resulting value at `indices`. The
/// values are read as text doubles and narrowed to V. Pattern files carry no
/// values, so every entry gets re = 1, im = 1. `indices` is taken by const
/// reference to avoid a vector copy per stored entry.
template <typename V>
static inline void readCOOValue(SparseTensorCOO<std::complex<V>> *coo,
                                const std::vector<uint64_t> &indices,
                                char **linePtr, bool is_pattern,
                                bool add_symmetric_value) {
  V re = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
  V im = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
  std::complex<V> value = {re, im};
  addValue(coo, value, indices, add_symmetric_value);
}
/// Non-complex overload: parses one value from `*linePtr`, advancing it, and
/// adds it at `indices`. Pattern files carry no values, so every entry gets 1.
/// The value is parsed as a double and converted to V when stored. `indices`
/// is taken by const reference to avoid a vector copy per stored entry.
template <typename V,
          typename std::enable_if<
              !std::is_same<std::complex<float>, V>::value &&
              !std::is_same<std::complex<double>, V>::value>::type * = nullptr>
static inline void readCOOValue(SparseTensorCOO<V> *coo,
                                const std::vector<uint64_t> &indices,
                                char **linePtr, bool is_pattern,
                                bool is_symmetric_value) {
  double value = is_pattern ? 1.0 : strtod(*linePtr, linePtr);
  addValue(coo, value, indices, is_symmetric_value);
}
/// Reads the tensor file `filename` into a newly allocated COO tensor of
/// element type V. The caller owns the result.
/// - Coordinates in the file are 1-based and are stored 0-based, permuted by
///   `perm`.
/// - `shape` is checked against the file's sizes (0 matches any size).
/// - `valTp` is checked for compatibility with the file's value kind.
/// - For symmetric matrices, each off-diagonal entry is also added mirrored.
template <typename V>
static SparseTensorCOO<V> *
openSparseTensorCOO(char *filename, uint64_t rank, const uint64_t *shape,
                    const uint64_t *perm, PrimaryType valTp) {
  SparseTensorFile stfile(filename);
  stfile.openFile();
  stfile.readHeader();
  // Reject value kinds that the requested element type cannot represent.
  // NOTE(review): these range tests assume the PrimaryType enumerators are
  // ordered kF64 ... kI64 ... kI8. "tensorIsReal" then means "not complex"
  // (floats and integers) — confirm against the header.
  SparseTensorFile::ValueKind valueKind = stfile.getValueKind();
  bool tensorIsInteger =
      (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8);
  bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8);
  if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) ||
      (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) {
    FATAL("Tensor element type %d not compatible with values in file %s\n",
          static_cast<int>(valTp), filename);
  }
  stfile.assertMatchesShape(rank, shape);
  uint64_t nnz = stfile.getNNZ();
  auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, stfile.getDimSizes(),
                                                     perm, nnz);
  std::vector<uint64_t> indices(rank);
  for (uint64_t k = 0; k < nnz; k++) {
    char *linePtr = stfile.readLine();
    for (uint64_t r = 0; r < rank; r++) {
      // strtoull, not strtoul: `unsigned long` is only 32 bits on LLP64
      // platforms (e.g. Windows), where strtoul would truncate large indices.
      uint64_t idx = strtoull(linePtr, &linePtr, 10);
      indices[perm[r]] = idx - 1; // The file uses 1-based coordinates.
    }
    readCOOValue(coo, indices, &linePtr, stfile.isPattern(),
                 stfile.isSymmetric() && indices[0] != indices[1]);
  }
  stfile.closeFile();
  return coo;
}
/// Writes the COO tensor `tensor` (a SparseTensorCOO<V>*) to the file named by
/// `dest` (a NUL-terminated char*) in extended FROSTT format:
/// - a comment line;
/// - the line "rank nnz";
/// - the dimension sizes;
/// - one line per element: 1-based coordinates followed by the value.
/// When `sort` is set, the elements are sorted first, which mutates the COO.
template <typename V>
static void outSparseTensor(void *tensor, void *dest, bool sort) {
  assert(tensor && dest);
  auto coo = static_cast<SparseTensorCOO<V> *>(tensor);
  if (sort)
    coo->sort();
  char *filename = static_cast<char *>(dest);
  auto &dimSizes = coo->getDimSizes();
  auto &elements = coo->getElements();
  uint64_t rank = coo->getRank();
  uint64_t nnz = elements.size();
  std::fstream file;
  file.open(filename, std::ios_base::out | std::ios_base::trunc);
  assert(file.is_open());
  // '\n' instead of std::endl: endl flushes after every line, which is very
  // slow for large tensors. The stream is flushed once at the end.
  file << "; extended FROSTT format\n" << rank << " " << nnz << '\n';
  // Space-separated sizes. Unlike a `rank - 1` loop bound, this cannot
  // underflow for a rank-0 tensor.
  for (uint64_t r = 0; r < rank; r++)
    file << (r == 0 ? "" : " ") << dimSizes[r];
  file << '\n';
  for (uint64_t i = 0; i < nnz; i++) {
    auto &idx = elements[i].indices;
    for (uint64_t r = 0; r < rank; r++)
      file << (idx[r] + 1) << " ";
    file << elements[i].value << '\n';
  }
  file.flush();
  file.close();
  assert(file.good());
}
/// Builds a SparseTensorStorage<uint64_t, uint64_t, V> from flat arrays and
/// returns it; the result is heap-allocated and owned by the caller.
/// - `shape`: `rank` sizes.
/// - `values`: `nse` values.
/// - `indices`: nse * rank coordinates, `rank` consecutive entries per value.
/// - `perm`: `rank` entries; coordinate r of each entry is stored at slot
///   perm[r].
/// - `sparse`: `rank` per-dimension level types given as raw bytes.
/// The intermediate COO is freed before returning.
template <typename V>
static SparseTensorStorage<uint64_t, uint64_t, V> *
toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
                   uint64_t *indices, uint64_t *perm, uint8_t *sparse) {
  // The level types arrive as raw bytes. Reinterpret them explicitly, rather
  // than with a C-style cast, as DimLevelType values.
  const DimLevelType *sparsity = reinterpret_cast<const DimLevelType *>(sparse);
#ifndef NDEBUG
  // Validate that `perm` is a permutation of 0..rank-1 and that every level
  // type is supported.
  std::vector<uint64_t> order(perm, perm + rank);
  std::sort(order.begin(), order.end());
  for (uint64_t i = 0; i < rank; ++i)
    if (i != order[i])
      FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
  for (uint64_t i = 0; i < rank; ++i)
    if (sparsity[i] != DimLevelType::kDense &&
        sparsity[i] != DimLevelType::kCompressed)
      FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
#endif
  auto *coo = SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm, nse);
  std::vector<uint64_t> idx(rank);
  for (uint64_t i = 0, base = 0; i < nse; i++) {
    for (uint64_t r = 0; r < rank; r++)
      idx[perm[r]] = indices[base + r];
    coo->add(idx, values[i]);
    base += rank;
  }
  auto *tensor = SparseTensorStorage<uint64_t, uint64_t, V>::newSparseTensor(
      rank, shape, perm, sparsity, coo);
  delete coo;
  return tensor;
}
template <typename V>
static void fromMLIRSparseTensor(void *tensor, uint64_t *pRank, uint64_t *pNse,
uint64_t **pShape, V **pValues,
uint64_t **pIndices) {
assert(tensor);
auto sparseTensor =
static_cast<SparseTensorStorage<uint64_t, uint64_t, V> *>(tensor);
uint64_t rank = sparseTensor->getRank();
std::vector<uint64_t> perm(rank);
std::iota(perm.begin(), perm.end(), 0);
SparseTensorCOO<V> *coo = sparseTensor->toCOO(perm.data());
const std::vector<Element<V>> &elements = coo->getElements();
uint64_t nse = elements.size();
uint64_t *shape = new uint64_t[rank];
for (uint64_t i = 0; i < rank; i++)
shape[i] = coo->getDimSizes()[i];
V *values = new V[nse];
uint64_t *indices = new uint64_t[rank * nse];
for (uint64_t i = 0, base = 0; i < nse; i++) {
values[i] = elements[i].value;
for (uint64_t j = 0; j < rank; j++)
indices[base + j] = elements[i].indices[j];
base += rank;
}
delete coo;
*pRank = rank;
*pNse = nse;
*pShape = shape;
*pValues = values;
*pIndices = indices;
}
}
extern "C" {
// Dispatch helper for _mlir_ciface_newSparseTensor. When the runtime type
// codes (ptrTp, indTp, valTp) equal (p, i, v), it handles `action` using the
// static types P, I, V and returns from the enclosing function. It relies on
// the enclosing scope for ptrTp, indTp, valTp, action, ptr, rank, shape, perm
// and sparsity.
// - action <= kFromCOO (presumably kEmpty, kFromFile and kFromCOO; this
//   depends on the Action enumerator order, confirm against the header):
//   builds a SparseTensorStorage.
//   - kFromFile: reads `ptr` (a filename) into a temporary COO, which is
//     deleted afterwards.
//   - kFromCOO: uses `ptr` as the COO; the caller keeps ownership.
//   - kEmpty: starts from no COO.
// - kSparseToSparse: converts the storage in `ptr` into this P/I/V layout.
// - kEmptyCOO: returns a new empty COO with permuted sizes.
// - Otherwise: converts the storage in `ptr` to a COO via toCOO(perm).
//   kToIterator also starts its iterator; kToCOO returns it as is.
#define CASE(p, i, v, P, I, V) \
if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
SparseTensorCOO<V> *coo = nullptr; \
if (action <= Action::kFromCOO) { \
if (action == Action::kFromFile) { \
char *filename = static_cast<char *>(ptr); \
coo = openSparseTensorCOO<V>(filename, rank, shape, perm, v); \
} else if (action == Action::kFromCOO) { \
coo = static_cast<SparseTensorCOO<V> *>(ptr); \
} else { \
assert(action == Action::kEmpty); \
} \
auto *tensor = SparseTensorStorage<P, I, V>::newSparseTensor( \
rank, shape, perm, sparsity, coo); \
if (action == Action::kFromFile) \
delete coo; \
return tensor; \
} \
if (action == Action::kSparseToSparse) { \
auto *tensor = static_cast<SparseTensorStorageBase *>(ptr); \
return SparseTensorStorage<P, I, V>::newSparseTensor(rank, shape, perm, \
sparsity, tensor); \
} \
if (action == Action::kEmptyCOO) \
return SparseTensorCOO<V>::newSparseTensorCOO(rank, shape, perm); \
coo = static_cast<SparseTensorStorage<P, I, V> *>(ptr)->toCOO(perm); \
if (action == Action::kToIterator) { \
coo->startIterator(); \
} else { \
assert(action == Action::kToCOO); \
} \
return coo; \
}
// Shorthand for CASE when pointers and indices use the same overhead type.
#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
// The memref data (index_type) is passed straight to helpers that take
// uint64_t pointers, so the two types must be identical.
static_assert(std::is_same<index_type, uint64_t>::value,
"Expected index_type == uint64_t");
void *
_mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref,
StridedMemRefType<index_type, 1> *sref,
StridedMemRefType<index_type, 1> *pref,
OverheadType ptrTp, OverheadType indTp,
PrimaryType valTp, Action action, void *ptr) {
assert(aref && sref && pref);
assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
pref->strides[0] == 1);
assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
const DimLevelType *sparsity = aref->data + aref->offset;
const index_type *shape = sref->data + sref->offset;
const index_type *perm = pref->data + pref->offset;
uint64_t rank = aref->sizes[0];
if (ptrTp == OverheadType::kIndex)
ptrTp = OverheadType::kU64;
if (indTp == OverheadType::kIndex)
indTp = OverheadType::kU64;
CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF64, uint64_t,
uint64_t, double);
CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF64, uint64_t,
uint32_t, double);
CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF64, uint64_t,
uint16_t, double);
CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF64, uint64_t,
uint8_t, double);
CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF64, uint32_t,
uint64_t, double);
CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF64, uint32_t,
uint32_t, double);
CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF64, uint32_t,
uint16_t, double);
CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF64, uint32_t,
uint8_t, double);
CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF64, uint16_t,
uint64_t, double);
CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF64, uint16_t,
uint32_t, double);
CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF64, uint16_t,
uint16_t, double);
CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF64, uint16_t,
uint8_t, double);
CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF64, uint8_t,
uint64_t, double);
CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF64, uint8_t,
uint32_t, double);
CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF64, uint8_t,
uint16_t, double);
CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF64, uint8_t,
uint8_t, double);
CASE(OverheadType::kU64, OverheadType::kU64, PrimaryType::kF32, uint64_t,
uint64_t, float);
CASE(OverheadType::kU64, OverheadType::kU32, PrimaryType::kF32, uint64_t,
uint32_t, float);
CASE(OverheadType::kU64, OverheadType::kU16, PrimaryType::kF32, uint64_t,
uint16_t, float);
CASE(OverheadType::kU64, OverheadType::kU8, PrimaryType::kF32, uint64_t,
uint8_t, float);
CASE(OverheadType::kU32, OverheadType::kU64, PrimaryType::kF32, uint32_t,
uint64_t, float);
CASE(OverheadType::kU32, OverheadType::kU32, PrimaryType::kF32, uint32_t,
uint32_t, float);
CASE(OverheadType::kU32, OverheadType::kU16, PrimaryType::kF32, uint32_t,
uint16_t, float);
CASE(OverheadType::kU32, OverheadType::kU8, PrimaryType::kF32, uint32_t,
uint8_t, float);
CASE(OverheadType::kU16, OverheadType::kU64, PrimaryType::kF32, uint16_t,
uint64_t, float);
CASE(OverheadType::kU16, OverheadType::kU32, PrimaryType::kF32, uint16_t,
uint32_t, float);
CASE(OverheadType::kU16, OverheadType::kU16, PrimaryType::kF32, uint16_t,
uint16_t, float);
CASE(OverheadType::kU16, OverheadType::kU8, PrimaryType::kF32, uint16_t,
uint8_t, float);
CASE(OverheadType::kU8, OverheadType::kU64, PrimaryType::kF32, uint8_t,
uint64_t, float);
CASE(OverheadType::kU8, OverheadType::kU32, PrimaryType::kF32, uint8_t,
uint32_t, float);
CASE(OverheadType::kU8, OverheadType::kU16, PrimaryType::kF32, uint8_t,
uint16_t, float);
CASE(OverheadType::kU8, OverheadType::kU8, PrimaryType::kF32, uint8_t,
uint8_t, float);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kF16, uint64_t, f16);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kBF16, uint64_t, bf16);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kF16, uint32_t, f16);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kBF16, uint32_t, bf16);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kF16, uint16_t, f16);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kBF16, uint16_t, bf16);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kF16, uint8_t, f16);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kBF16, uint8_t, bf16);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI64, uint64_t, int64_t);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI32, uint64_t, int32_t);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI16, uint64_t, int16_t);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kI8, uint64_t, int8_t);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI64, uint32_t, int64_t);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI32, uint32_t, int32_t);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI16, uint32_t, int16_t);
CASE_SECSAME(OverheadType::kU32, PrimaryType::kI8, uint32_t, int8_t);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI64, uint16_t, int64_t);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI32, uint16_t, int32_t);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI16, uint16_t, int16_t);
CASE_SECSAME(OverheadType::kU16, PrimaryType::kI8, uint16_t, int8_t);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI64, uint8_t, int64_t);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI32, uint8_t, int32_t);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI16, uint8_t, int16_t);
CASE_SECSAME(OverheadType::kU8, PrimaryType::kI8, uint8_t, int8_t);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kC64, uint64_t, complex64);
CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
static_cast<int>(ptrTp), static_cast<int>(indTp),
static_cast<int>(valTp));
}
#undef CASE
#undef CASE_SECSAME
// Defines `_mlir_ciface_sparseValues<VNAME>` for every value type.
//
// Each function fills the rank-1 memref `ref` so that it views the vector
// obtained from `getValues` on the storage object `tensor`:
//   - data and base pointer are set to that vector's data();
//   - offset is 0 and stride is 1;
//   - size is the vector's element count.
// No elements are copied; the memref aliases the vector's buffer.
#define IMPL_SPARSEVALUES(VNAME, V) \
  void _mlir_ciface_sparseValues##VNAME(StridedMemRefType<V, 1> *ref, \
                                        void *tensor) { \
    assert(ref &&tensor); \
    auto *storage = static_cast<SparseTensorStorageBase *>(tensor); \
    std::vector<V> *values; \
    storage->getValues(&values); \
    V *buffer = values->data(); \
    ref->data = buffer; \
    ref->basePtr = buffer; \
    ref->offset = 0; \
    ref->strides[0] = 1; \
    ref->sizes[0] = values->size(); \
  }
FOREVERY_V(IMPL_SPARSEVALUES)
#undef IMPL_SPARSEVALUES
// Defines `_mlir_ciface_<NAME>(ref, tensor, d)`.
//
// The function calls the storage accessor `LIB` for dimension `d` and makes
// the rank-1 memref `ref` view the returned std::vector<TYPE>:
//   - data and base pointer are set to the vector's data();
//   - offset is 0 and stride is 1;
//   - size is the vector's element count.
// No elements are copied.
#define IMPL_GETOVERHEAD(NAME, TYPE, LIB) \
  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
                           index_type d) { \
    assert(ref &&tensor); \
    auto *storage = static_cast<SparseTensorStorageBase *>(tensor); \
    std::vector<TYPE> *overhead; \
    storage->LIB(&overhead, d); \
    TYPE *buffer = overhead->data(); \
    ref->data = buffer; \
    ref->basePtr = buffer; \
    ref->offset = 0; \
    ref->strides[0] = 1; \
    ref->sizes[0] = overhead->size(); \
  }
// `_mlir_ciface_sparsePointers<PNAME>`: per-dimension pointer arrays.
#define IMPL_SPARSEPOINTERS(PNAME, P) \
  IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
FOREVERY_O(IMPL_SPARSEPOINTERS)
#undef IMPL_SPARSEPOINTERS
// `_mlir_ciface_sparseIndices<INAME>`: per-dimension index arrays.
#define IMPL_SPARSEINDICES(INAME, I) \
  IMPL_GETOVERHEAD(sparseIndices##INAME, I, getIndices)
FOREVERY_O(IMPL_SPARSEINDICES)
#undef IMPL_SPARSEINDICES
#undef IMPL_GETOVERHEAD
// Defines `_mlir_ciface_addElt<VNAME>`, which appends one element to the COO
// object `coo` and returns `coo`.
//   - `iref` holds the element's indices; `pref` holds a permutation of the
//     same length.
//   - Index `r` of `iref` is stored at position `perm[r]` before the call to
//     SparseTensorCOO<V>::add.
//   - `vref` is a rank-0 memref holding the element value.
#define IMPL_ADDELT(VNAME, V) \
  void *_mlir_ciface_addElt##VNAME(void *coo, StridedMemRefType<V, 0> *vref, \
                                   StridedMemRefType<index_type, 1> *iref, \
                                   StridedMemRefType<index_type, 1> *pref) { \
    assert(coo &&vref &&iref &&pref); \
    assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
    assert(iref->sizes[0] == pref->sizes[0]); \
    const uint64_t count = iref->sizes[0]; \
    const index_type *src = iref->data + iref->offset; \
    const index_type *permutation = pref->data + pref->offset; \
    std::vector<index_type> permuted(count); \
    for (uint64_t k = 0; k < count; ++k) \
      permuted[permutation[k]] = src[k]; \
    const V &element = vref->data[vref->offset]; \
    static_cast<SparseTensorCOO<V> *>(coo)->add(permuted, element); \
    return coo; \
  }
FOREVERY_V(IMPL_ADDELT)
#undef IMPL_ADDELT
// Defines `_mlir_ciface_getNext<VNAME>`, which advances the COO iterator of
// `coo` by one element.
//   - When `getNext()` yields nullptr, returns false without touching
//     `iref` or `vref`.
//   - Otherwise copies `iref->sizes[0]` indices of the element into `iref`,
//     stores its value into the rank-0 memref `vref`, and returns true.
#define IMPL_GETNEXT(VNAME, V) \
  bool _mlir_ciface_getNext##VNAME(void *coo, \
                                   StridedMemRefType<index_type, 1> *iref, \
                                   StridedMemRefType<V, 0> *vref) { \
    assert(coo &&iref &&vref); \
    assert(iref->strides[0] == 1); \
    auto *source = static_cast<SparseTensorCOO<V> *>(coo); \
    const Element<V> *next = source->getNext(); \
    if (!next) \
      return false; \
    index_type *dst = iref->data + iref->offset; \
    const uint64_t count = iref->sizes[0]; \
    for (uint64_t k = 0; k < count; ++k) \
      dst[k] = next->indices[k]; \
    vref->data[vref->offset] = next->value; \
    return true; \
  }
FOREVERY_V(IMPL_GETNEXT)
#undef IMPL_GETNEXT
// Defines `_mlir_ciface_lexInsert<VNAME>`, which forwards the index tuple in
// `cref` and the scalar held by the rank-0 memref `vref` to
// SparseTensorStorageBase::lexInsert on `tensor`.
#define IMPL_LEXINSERT(VNAME, V) \
  void _mlir_ciface_lexInsert##VNAME(void *tensor, \
                                     StridedMemRefType<index_type, 1> *cref, \
                                     StridedMemRefType<V, 0> *vref) { \
    assert(tensor &&cref &&vref); \
    assert(cref->strides[0] == 1); \
    index_type *cursor = cref->data + cref->offset; \
    assert(cursor); \
    auto *storage = static_cast<SparseTensorStorageBase *>(tensor); \
    storage->lexInsert(cursor, vref->data[vref->offset]); \
  }
FOREVERY_V(IMPL_LEXINSERT)
#undef IMPL_LEXINSERT
// Defines `_mlir_ciface_expInsert<VNAME>`.
//
// The function takes the raw data pointers (data + offset) of four
// unit-stride rank-1 memrefs:
//   - `cref`: cursor;
//   - `vref`: values;
//   - `fref`: filled flags;
//   - `aref`: added list.
// It passes them, together with `count`, to
// SparseTensorStorageBase::expInsert on `tensor`. `vref` and `fref` must have
// the same length.
#define IMPL_EXPINSERT(VNAME, V) \
  void _mlir_ciface_expInsert##VNAME( \
      void *tensor, StridedMemRefType<index_type, 1> *cref, \
      StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
      StridedMemRefType<index_type, 1> *aref, index_type count) { \
    assert(tensor &&cref &&vref &&fref &&aref); \
    assert(cref->strides[0] == 1); \
    assert(vref->strides[0] == 1); \
    assert(fref->strides[0] == 1); \
    assert(aref->strides[0] == 1); \
    assert(vref->sizes[0] == fref->sizes[0]); \
    auto *storage = static_cast<SparseTensorStorageBase *>(tensor); \
    storage->expInsert(cref->data + cref->offset, vref->data + vref->offset, \
                       fref->data + fref->offset, aref->data + aref->offset, \
                       count); \
  }
FOREVERY_V(IMPL_EXPINSERT)
#undef IMPL_EXPINSERT
/// Returns `getDimSize(d)` of the sparse tensor storage object `tensor`.
index_type sparseDimSize(void *tensor, index_type d) {
  auto *storage = static_cast<SparseTensorStorageBase *>(tensor);
  return storage->getDimSize(d);
}
/// Calls `endInsert()` on the sparse tensor storage object `tensor`.
void endInsert(void *tensor) {
  auto *storage = static_cast<SparseTensorStorageBase *>(tensor);
  storage->endInsert();
}
// Defines `outSparseTensor<VNAME>(coo, dest, sort)` for every value type. Each
// forwards all three arguments unchanged to the templated
// outSparseTensor<V> helper.
#define IMPL_OUTSPARSETENSOR(VNAME, V) \
  void outSparseTensor##VNAME(void *coo, void *dest, bool sort) { \
    outSparseTensor<V>(coo, dest, sort); \
  }
FOREVERY_V(IMPL_OUTSPARSETENSOR)
#undef IMPL_OUTSPARSETENSOR
void delSparseTensor(void *tensor) {
delete static_cast<SparseTensorStorageBase *>(tensor);
}
#define IMPL_DELCOO(VNAME, V) \
void delSparseTensorCOO##VNAME(void *coo) { \
delete static_cast<SparseTensorCOO<V> *>(coo); \
}
FOREVERY_V(IMPL_DELCOO)
#undef IMPL_DELCOO
/// Returns the value of the environment variable `TENSOR<id>`, e.g.
/// `TENSOR0` for id 0.
///
/// The returned pointer is the one produced by getenv. If the variable is
/// not set, an error is printed to stderr and the process exits via FATAL.
char *getTensorFilename(index_type id) {
  // "TENSOR" plus at most 20 decimal digits (index_type is uint64_t) plus the
  // terminator fits easily. snprintf additionally bounds the write to the
  // buffer instead of relying on that arithmetic, as sprintf does.
  char var[80];
  snprintf(var, sizeof(var), "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  if (!env)
    FATAL("Environment variable %s is not set\n", var);
  return env;
}
void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
assert(out && "Received nullptr for out-parameter");
SparseTensorFile stfile(filename);
stfile.openFile();
stfile.readHeader();
stfile.closeFile();
const uint64_t rank = stfile.getRank();
const uint64_t *dimSizes = stfile.getDimSizes();
out->reserve(rank);
out->assign(dimSizes, dimSizes + rank);
}
// Defines `convertToMLIRSparseTensor<VNAME>` for every value type.
//
// Each is a thin wrapper returning, as an opaque pointer, the result of
// toMLIRSparseTensor<V> on the given arrays:
//   - `nse`: presumably the number of stored entries (confirm with
//     toMLIRSparseTensor);
//   - `shape`, `values`, `indices`, `perm`, `sparse`: forwarded unchanged.
#define IMPL_CONVERTTOMLIRSPARSETENSOR(VNAME, V) \
  void *convertToMLIRSparseTensor##VNAME( \
      uint64_t rank, uint64_t nse, uint64_t *shape, V *values, \
      uint64_t *indices, uint64_t *perm, uint8_t *sparse) { \
    void *result = toMLIRSparseTensor<V>(rank, nse, shape, values, indices, \
                                         perm, sparse); \
    return result; \
  }
FOREVERY_V(IMPL_CONVERTTOMLIRSPARSETENSOR)
#undef IMPL_CONVERTTOMLIRSPARSETENSOR
// Defines `convertFromMLIRSparseTensor<VNAME>` for every value type.
//
// Each is a thin wrapper that forwards `tensor` and the five out-parameters
// (`pRank`, `pNse`, `pShape`, `pValues`, `pIndices`) unchanged to
// fromMLIRSparseTensor<V>.
#define IMPL_CONVERTFROMMLIRSPARSETENSOR(VNAME, V) \
  void convertFromMLIRSparseTensor##VNAME(void *tensor, uint64_t *pRank, \
                                          uint64_t *pNse, uint64_t **pShape, \
                                          V **pValues, uint64_t **pIndices) { \
    fromMLIRSparseTensor<V>(tensor, pRank, pNse, pShape, pValues, \
                            pIndices); \
  }
FOREVERY_V(IMPL_CONVERTFROMMLIRSPARSETENSOR)
#undef IMPL_CONVERTFROMMLIRSPARSETENSOR
}
#endif