#include "mlir/ExecutionEngine/SparseTensor/File.h"
#include <cctype>
#include <cstring>
using namespace mlir::sparse_tensor;
/// Opens `filename` for reading and stores the resulting handle in `file`.
/// Prints a diagnostic to stderr and terminates the process with exit
/// code 1 if a file is already open or if `fopen` fails.
void SparseTensorReader::openFile() {
  // Refuse to open a second time while a handle is still held.
  if (file != nullptr) {
    fprintf(stderr, "Already opened file %s\n", filename);
    exit(1);
  }
  FILE *const handle = fopen(filename, "r");
  if (handle == nullptr) {
    fprintf(stderr, "Cannot find file %s\n", filename);
    exit(1);
  }
  file = handle;
}
/// Closes the currently open file, if any, and resets `file` to null so
/// that a subsequent call is a no-op.
void SparseTensorReader::closeFile() {
  if (file == nullptr)
    return;
  fclose(file);
  file = nullptr;
}
/// Reads the next line of the file into the `line` buffer, using `fgets`
/// with a limit of `kColWidth`. Prints a diagnostic to stderr and
/// terminates the process with exit code 1 if nothing could be read.
void SparseTensorReader::readLine() {
  if (fgets(line, kColWidth, file) != nullptr)
    return;
  fprintf(stderr, "Cannot read next line of %s\n", filename);
  exit(1);
}
/// Reads the file header, selecting the parser from the filename:
/// names containing ".mtx" go to `readMMEHeader()`, otherwise names
/// containing ".tns" go to `readExtFROSTTHeader()`. Any other name prints
/// a diagnostic to stderr and terminates the process with exit code 1.
///
/// The file must already be open (checked by assertion).
void SparseTensorReader::readHeader() {
  assert(file && "Attempt to readHeader() before openFile()");
  // NOTE(review): these are substring matches anywhere in `filename`, not
  // suffix matches, and ".mtx" takes precedence over ".tns".
  if (strstr(filename, ".mtx") != nullptr) {
    readMMEHeader();
  } else if (strstr(filename, ".tns") == nullptr) {
    fprintf(stderr, "Unknown format %s\n", filename);
    exit(1);
  } else {
    readExtFROSTTHeader();
  }
  assert(isValid() && "Failed to read the header");
}
/// Asserts that the header read from the file is compatible with the
/// expected shape: `rank` must equal `getRank()`, and every entry of
/// `shape` must either be zero or equal the matching dimension size
/// stored in `idata[2 + d]`. A zero entry therefore matches any size.
///
/// All checks are plain `assert`s, so nothing is verified when assertions
/// are compiled out.
void SparseTensorReader::assertMatchesShape(uint64_t rank,
                                            const uint64_t *shape) const {
  assert(rank == getRank() && "Rank mismatch");
  for (uint64_t d = 0; d != rank; ++d) {
    assert((shape[d] == 0 || shape[d] == idata[2 + d]) &&
           "Dimension size mismatch");
  }
}
/// Reports whether values of the kind recorded in `valueKind_` may be read
/// as the primary type `valTy`:
///   - kInvalid:              asserts (header not read yet), returns false;
///   - kPattern:              always true;
///   - kInteger, kUndefined:  `isRealPrimaryType(valTy)`;
///   - kReal:                 `isFloatingPrimaryType(valTy)`;
///   - kComplex:              `isComplexPrimaryType(valTy)`.
/// A value outside the enumeration is reported on stderr and yields false.
bool SparseTensorReader::canReadAs(PrimaryType valTy) const {
  switch (valueKind_) {
  case ValueKind::kInvalid:
    assert(false && "Must readHeader() before calling canReadAs()");
    return false;
  case ValueKind::kPattern:
    return true;
  // Both kinds are governed by the same predicate.
  case ValueKind::kInteger:
  case ValueKind::kUndefined:
    return isRealPrimaryType(valTy);
  case ValueKind::kReal:
    return isFloatingPrimaryType(valTy);
  case ValueKind::kComplex:
    return isComplexPrimaryType(valTy);
  }
  // Only reachable if `valueKind_` holds a value that is not an enumerator.
  fprintf(stderr, "Unknown ValueKind: %d\n", static_cast<uint8_t>(valueKind_));
  return false;
}
/// Converts the NUL-terminated string `token` to lowercase in place, using
/// `tolower` on each character.
static inline void toLower(char *token) {
  // `tolower` requires an argument representable as `unsigned char` (or
  // EOF); passing a negative plain `char` is undefined behaviour, so each
  // byte is converted through `unsigned char` first.
  for (char *c = token; *c; c++)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
}
/// Returns true iff the two NUL-terminated strings compare equal under
/// `strcmp`.
static inline bool streq(const char *lhs, const char *rhs) {
  return !strcmp(lhs, rhs);
}
/// Returns true iff the two NUL-terminated strings compare unequal under
/// `strcmp`.
static inline bool strne(const char *lhs, const char *rhs) {
  return strcmp(lhs, rhs) != 0;
}
/// Reads the header of a file handled as MatrixMarket Exchange (".mtx").
///
/// The first line must consist of five whitespace-separated tokens
/// (header, object, format, field, symmetry), compared case-insensitively.
/// The field token selects `valueKind_`; the symmetry token sets
/// `isSymmetric_`. Only `%%MatrixMarket matrix coordinate` files with
/// `general` or `symmetric` symmetry are accepted. Lines starting with '%'
/// are then skipped, and the first other line must hold three unsigned
/// integers, stored in `idata[2]`, `idata[3]` and `idata[1]`; `idata[0]`
/// is set to 2.
///
/// Any violation prints a diagnostic to stderr and terminates the process
/// with exit code 1.
void SparseTensorReader::readMMEHeader() {
  // Each buffer holds at most 63 characters plus the terminating NUL,
  // matching the "%63s" conversions below.
  char header[64], object[64], format[64], field[64], symmetry[64];
  const int numTokens = fscanf(file, "%63s %63s %63s %63s %63s\n", header,
                               object, format, field, symmetry);
  if (numTokens != 5) {
    fprintf(stderr, "Corrupt header in %s\n", filename);
    exit(1);
  }

  // Lowercase every token so the comparisons below ignore case.
  char *const tokens[] = {header, object, format, field, symmetry};
  for (char *token : tokens)
    toLower(token);

  // Derive the value kind from the field token.
  if (streq(field, "pattern"))
    valueKind_ = ValueKind::kPattern;
  else if (streq(field, "real"))
    valueKind_ = ValueKind::kReal;
  else if (streq(field, "integer"))
    valueKind_ = ValueKind::kInteger;
  else if (streq(field, "complex"))
    valueKind_ = ValueKind::kComplex;
  else {
    fprintf(stderr, "Unexpected header field value in %s\n", filename);
    exit(1);
  }

  // Validate the remaining tokens; "general" and "symmetric" are the only
  // accepted symmetry values.
  isSymmetric_ = streq(symmetry, "symmetric");
  const bool hasBanner = streq(header, "%%matrixmarket");
  const bool isMatrix = streq(object, "matrix");
  const bool isCoordinate = streq(format, "coordinate");
  const bool isGeneral = streq(symmetry, "general");
  if (!(hasBanner && isMatrix && isCoordinate &&
        (isGeneral || isSymmetric_))) {
    fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
    exit(1);
  }

  // Skip lines starting with '%'; stop at the first line that does not.
  do {
    readLine();
  } while (line[0] == '%');

  // The rank is fixed at 2. The size line is parsed into idata[2], idata[3]
  // (the two dimension sizes) and idata[1] (presumably the number of stored
  // entries, as in the MatrixMarket size line -- confirm against the
  // accessors in the class declaration).
  idata[0] = 2;
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3) {
    fprintf(stderr, "Cannot find size in %s\n", filename);
    exit(1);
  }
}
/// Reads the header of a file handled as extended FROSTT (".tns").
///
/// Lines starting with '#' are skipped. The first other line must hold two
/// unsigned integers, stored in `idata[0]` (the rank, i.e. the number of
/// dimension sizes that follow) and `idata[1]` (presumably the number of
/// stored entries -- confirm against the accessors in the class
/// declaration). The dimension sizes are then read into `idata[2 + r]`,
/// the remainder of that line is consumed, and `valueKind_` is set to
/// `kUndefined`.
///
/// Any violation, including a rank that does not fit in `idata`, prints a
/// diagnostic to stderr and terminates the process with exit code 1.
void SparseTensorReader::readExtFROSTTHeader() {
  // The capacity computation below is only meaningful if `idata` is a
  // fixed-size array with room for the two metadata slots plus at least
  // one dimension size. This fails to compile otherwise (e.g. if `idata`
  // were a pointer).
  static_assert(sizeof(idata) / sizeof(idata[0]) > 2,
                "idata must be a fixed-size array with more than two slots");
  const uint64_t maxRank = sizeof(idata) / sizeof(idata[0]) - 2;

  // Skip lines starting with '#'; stop at the first line that does not.
  while (true) {
    readLine();
    if (line[0] != '#')
      break;
  }
  // This line holds the rank and the second metadata value.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", filename);
    exit(1);
  }
  // The rank comes straight from the file and drives the writes below, so
  // reject values that would run past the end of `idata`.
  if (idata[0] > maxRank) {
    fprintf(stderr, "Rank %" PRIu64 " is too large in %s\n", idata[0],
            filename);
    exit(1);
  }
  // Read one dimension size per rank.
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size %s\n", filename);
      exit(1);
    }
  }
  // Consume the rest of the dimension-size line.
  readLine();
  // No value kind is derived from this header.
  valueKind_ = ValueKind::kUndefined;
}