//===- File.cpp - Reading/writing sparse tensors from/to files ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements reading and writing sparse tensor files.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/SparseTensor/File.h"

#include <cctype>
#include <cstring>
#include <iterator>

using namespace mlir::sparse_tensor;

/// Opens `filename` in read mode and stores the resulting stream in `file`.
/// Prints a diagnostic and terminates the process if a stream is already
/// open, or if the file cannot be opened.
void SparseTensorReader::openFile() {
  // Opening twice would leak the previous stream, so treat it as fatal.
  if (file != nullptr) {
    fprintf(stderr, "Already opened file %s\n", filename);
    exit(1);
  }
  file = fopen(filename, "r");
  if (file != nullptr)
    return;
  fprintf(stderr, "Cannot find file %s\n", filename);
  exit(1);
}

/// Closes the stream held in `file`, if there is one, and resets `file` to
/// null.  Does nothing when no stream is open.
void SparseTensorReader::closeFile() {
  if (!file)
    return;
  fclose(file);
  file = nullptr;
}

/// Reads the next line of the file into the `line` buffer (via `fgets`
/// with a limit of `kColWidth`).  Prints a diagnostic and terminates the
/// process when nothing can be read.
void SparseTensorReader::readLine() {
  if (fgets(line, kColWidth, file))
    return;
  fprintf(stderr, "Cannot read next line of %s\n", filename);
  exit(1);
}

/// Reads and parses the file's header.  The header format is selected from
/// the filename: ".mtx" selects the MME header parser and ".tns" selects the
/// "extended" FROSTT header parser.  The actual suffix of the filename takes
/// precedence, so that a path such as "dir.mtx/tensor.tns" is not mistaken
/// for an MME file; only when the filename ends in neither extension do we
/// fall back to searching for the extension anywhere in the name.  Prints a
/// diagnostic and terminates the process when no known format is found.
void SparseTensorReader::readHeader() {
  assert(file && "Attempt to readHeader() before openFile()");
  // Returns true iff `str` ends with `ext`.
  const auto endsWith = [](const char *str, const char *ext) {
    const size_t strLen = strlen(str);
    const size_t extLen = strlen(ext);
    return strLen >= extLen && strcmp(str + strLen - extLen, ext) == 0;
  };
  if (endsWith(filename, ".mtx")) {
    readMMEHeader();
  } else if (endsWith(filename, ".tns")) {
    readExtFROSTTHeader();
  } else if (strstr(filename, ".mtx")) {
    // Backward-compatible fallback: extension somewhere inside the name.
    readMMEHeader();
  } else if (strstr(filename, ".tns")) {
    readExtFROSTTHeader();
  } else {
    fprintf(stderr, "Unknown format %s\n", filename);
    exit(1);
  }
  assert(isValid() && "Failed to read the header");
}

/// Asserts that `rank` equals the rank read from the file and that every
/// entry of `shape` is either zero (which matches any size) or equal to the
/// corresponding dimension size stored in `idata`.  Only meaningful after
/// the header has been parsed; compiles to nothing when assertions are
/// disabled.
void SparseTensorReader::assertMatchesShape(uint64_t rank,
                                            const uint64_t *shape) const {
  assert(rank == getRank() && "Rank mismatch");
  // Dimension sizes are stored in `idata` starting at index 2.
  for (uint64_t d = 0; d != rank; ++d)
    assert((!shape[d] || idata[d + 2] == shape[d]) &&
           "Dimension size mismatch");
}

/// Returns whether values of the kind announced by the file's header may be
/// read as the primary type `valTy`.  Must be called after the header has
/// been read.
bool SparseTensorReader::canReadAs(PrimaryType valTy) const {
  switch (valueKind_) {
  case ValueKind::kInvalid:
    assert(false && "Must readHeader() before calling canReadAs()");
    return false; // In case assertions are disabled.
  case ValueKind::kPattern:
    // Pattern files are accepted for every primary type.
    return true;
  case ValueKind::kInteger:
    // Integer files may additionally be converted implicitly to floating
    // primary-types.
  case ValueKind::kUndefined:
    // The "extended" FROSTT format doesn't specify a ValueKind, so the
    // stored values may be read as either integer or floating
    // primary-types.
    return isRealPrimaryType(valTy);
  case ValueKind::kReal:
    // Real/floating files are never converted implicitly to integer
    // primary-types.
    return isFloatingPrimaryType(valTy);
  case ValueKind::kComplex:
    // Complex files require a complex primary-type.
    return isComplexPrimaryType(valTy);
  }
  // Only reachable when `valueKind_` holds a value outside the enumeration.
  fprintf(stderr, "Unknown ValueKind: %d\n", static_cast<uint8_t>(valueKind_));
  return false;
}

/// Helper to convert C-style strings (i.e., '\0' terminated) to lower case,
/// in place.  Each character is passed to `tolower` as an `unsigned char`,
/// since calling `tolower` with a negative `char` value (e.g., a non-ASCII
/// byte on platforms where `char` is signed) is undefined behavior.
static inline void toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = static_cast<char>(tolower(static_cast<unsigned char>(*c)));
}

/// Returns true iff the two C-style strings have identical contents.
static inline bool streq(const char *lhs, const char *rhs) {
  return !strcmp(lhs, rhs);
}

/// Returns true iff the two C-style strings differ in contents.
static inline bool strne(const char *lhs, const char *rhs) {
  const int order = strcmp(lhs, rhs);
  return order != 0;
}

/// Read the MME header of a general sparse matrix of type real.
void SparseTensorReader::readMMEHeader() {
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  // Read header line.
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5) {
    fprintf(stderr, "Corrupt header in %s\n", filename);
    exit(1);
  }
  // Convert all to lowercase up front (to avoid accidental redundancy).
  toLower(header);
  toLower(object);
  toLower(format);
  toLower(field);
  toLower(symmetry);
  // Process `field`, which specify pattern or the data type of the values.
  if (streq(field, "pattern")) {
    valueKind_ = ValueKind::kPattern;
  } else if (streq(field, "real")) {
    valueKind_ = ValueKind::kReal;
  } else if (streq(field, "integer")) {
    valueKind_ = ValueKind::kInteger;
  } else if (streq(field, "complex")) {
    valueKind_ = ValueKind::kComplex;
  } else {
    fprintf(stderr, "Unexpected header field value in %s\n", filename);
    exit(1);
  }
  // Set properties.
  isSymmetric_ = streq(symmetry, "symmetric");
  // Make sure this is a general sparse matrix.
  if (strne(header, "%%matrixmarket") || strne(object, "matrix") ||
      strne(format, "coordinate") ||
      (strne(symmetry, "general") && !isSymmetric_)) {
    fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
    exit(1);
  }
  // Skip comments.
  while (true) {
    readLine();
    if (line[0] != '%')
      break;
  }
  // Next line contains M N NNZ.
  idata[0] = 2; // rank
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3) {
    fprintf(stderr, "Cannot find size in %s\n", filename);
    exit(1);
  }
}

/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments (lines
/// starting with '#') followed by two lines that define the rank, the number
/// of nonzeros, and the dimensions sizes (one per rank) of the sparse tensor.
/// On success the rank is stored in `idata[0]`, the number of nonzeros in
/// `idata[1]`, and the dimension sizes in `idata[2]` onwards.  Any malformed
/// header, or a rank whose dimension sizes would not fit in `idata`, prints
/// a diagnostic and terminates the process.
void SparseTensorReader::readExtFROSTTHeader() {
  // Skip comments.
  while (true) {
    readLine();
    if (line[0] != '#')
      break;
  }
  // Next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", filename);
    exit(1);
  }
  // The rank comes straight from the file, so make sure all dimension sizes
  // fit in `idata` (whose first two slots hold the rank and the number of
  // nonzeros) before writing them.
  const uint64_t maxRank = std::size(idata) - 2;
  if (idata[0] > maxRank) {
    fprintf(stderr, "Rank %" PRIu64 " exceeds maximum %" PRIu64 " in %s\n",
            idata[0], maxRank, filename);
    exit(1);
  }
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size %s\n", filename);
      exit(1);
    }
  }
  readLine(); // end of line
  // The FROSTT format does not define the data type of the nonzero elements.
  valueKind_ = ValueKind::kUndefined;
}