//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "Utils/CodegenUtils.h"

#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

#include <optional>

using namespace mlir;
using namespace sparse_tensor;

namespace {

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
  MLIRContext *ctx = tp.getContext();
  auto enc = tp.getEncoding();
  const Level lvlRank = enc.getLvlRank();

  SmallVector<Type, 4> result;
  // TODO: how can we get the lowering type for index type in the later pipeline
  // to be consistent? LLVM::StructureType does not allow index fields.
  auto sizeType = IntegerType::get(tp.getContext(), 64);
  auto lvlSizes = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
  auto memSizes = LLVM::LLVMArrayType::get(ctx, sizeType,
                                           getNumDataFieldsFromEncoding(enc));
  result.push_back(lvlSizes);
  result.push_back(memSizes);

  if (enc.isSlice()) {
    // Extra fields are required for the slice information.
    auto dimOffset = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);
    auto dimStride = LLVM::LLVMArrayType::get(ctx, sizeType, lvlRank);

    result.push_back(dimOffset);
    result.push_back(dimStride);
  }

  return result;
}

static Type convertSpecifier(StorageSpecifierType tp) {
  return LLVM::LLVMStructType::getLiteral(tp.getContext(),
                                          getSpecifierFields(tp));
}

//===----------------------------------------------------------------------===//
// Specifier struct builder.
//===----------------------------------------------------------------------===//

constexpr uint64_t kLvlSizePosInSpecifier = 0;
constexpr uint64_t kMemSizePosInSpecifier = 1;
constexpr uint64_t kDimOffsetPosInSpecifier = 2;
constexpr uint64_t kDimStridePosInSpecifier = 3;

class SpecifierStructBuilder : public StructBuilder {
private:
  Value extractField(OpBuilder &builder, Location loc,
                     ArrayRef<int64_t> indices) const {
    return genCast(builder, loc,
                   builder.create<LLVM::ExtractValueOp>(loc, value, indices),
                   builder.getIndexType());
  }

  void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
                   Value v) {
    value = builder.create<LLVM::InsertValueOp>(
        loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
        indices);
  }

public:
  explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
    assert(value);
  }

  // Undef value for dimension sizes, all zero value for memory sizes.
  static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
                            Value source);

  Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
  void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);

  Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
  void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
                    Value size);

  Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
  void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
                    Value size);

  Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
  void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
                  Value size);

  Value memSizeArray(OpBuilder &builder, Location loc) const;
  void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
};

Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
                                           Type structType, Value source) {
  Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
  SpecifierStructBuilder md(metaData);
  if (!source) {
    auto memSizeArrayType =
        cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
                                      .getBody()[kMemSizePosInSpecifier]);

    Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
    // Fill memSizes array with zero.
    for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
      md.setMemSize(builder, loc, i, zero);
  } else {
    // We copy non-slice information (memory sizes array) from source
    SpecifierStructBuilder sourceMd(source);
    md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
  }
  return md;
}

/// Builds IR extracting the pos-th offset from the descriptor.
Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
                                        Dimension dim) const {
  return extractField(
      builder, loc,
      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
}

/// Builds IR inserting the pos-th offset into the descriptor.
void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
                                          Dimension dim, Value size) {
  insertField(
      builder, loc,
      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
      size);
}

/// Builds IR extracting the `lvl`-th level-size from the descriptor.
Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
                                      Level lvl) const {
  // This static_cast makes the narrowing of `lvl` explicit, as required
  // by the braces notation for the ctor.
  return extractField(
      builder, loc,
      ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)});
}

/// Builds IR inserting the `lvl`-th level-size into the descriptor.
void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
                                        Level lvl, Value size) {
  // This static_cast makes the narrowing of `lvl` explicit, as required
  // by the braces notation for the ctor.
  insertField(
      builder, loc,
      ArrayRef<int64_t>{kLvlSizePosInSpecifier, static_cast<int64_t>(lvl)},
      size);
}

/// Builds IR extracting the pos-th stride from the descriptor.
Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
                                        Dimension dim) const {
  return extractField(
      builder, loc,
      ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
}

/// Builds IR inserting the pos-th stride into the descriptor.
void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
                                          Dimension dim, Value size) {
  insertField(
      builder, loc,
      ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
      size);
}

/// Builds IR extracting the pos-th memory size into the descriptor.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
                                      FieldIndex fidx) const {
  return extractField(
      builder, loc,
      ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
}

/// Builds IR inserting the `fidx`-th memory-size into the descriptor.
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
                                        FieldIndex fidx, Value size) {
  insertField(
      builder, loc,
      ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)},
      size);
}

/// Builds IR extracting the memory size array from the descriptor.
Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
                                           Location loc) const {
  return builder.create<LLVM::ExtractValueOp>(loc, value,
                                              kMemSizePosInSpecifier);
}

/// Builds IR inserting the memory size array into the descriptor.
void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
                                             Value array) {
  value = builder.create<LLVM::InsertValueOp>(loc, value, array,
                                              kMemSizePosInSpecifier);
}

} // namespace

//===----------------------------------------------------------------------===//
// The sparse storage specifier type converter (defined in Passes.h).
//===----------------------------------------------------------------------===//

StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
  addConversion([](Type type) { return type; });
  addConversion(convertSpecifier);
}

//===----------------------------------------------------------------------===//
// Storage specifier conversion rules.
//===----------------------------------------------------------------------===//

template <typename Base, typename SourceOp>
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
public:
  using OpAdaptor = typename SourceOp::Adaptor;
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SpecifierStructBuilder spec(adaptor.getSpecifier());
    switch (op.getSpecifierKind()) {
    case StorageSpecifierKind::LvlSize: {
      Value v = Base::onLvlSize(rewriter, op, spec, (*op.getLevel()));
      rewriter.replaceOp(op, v);
      return success();
    }
    case StorageSpecifierKind::DimOffset: {
      Value v = Base::onDimOffset(rewriter, op, spec, (*op.getLevel()));
      rewriter.replaceOp(op, v);
      return success();
    }
    case StorageSpecifierKind::DimStride: {
      Value v = Base::onDimStride(rewriter, op, spec, (*op.getLevel()));
      rewriter.replaceOp(op, v);
      return success();
    }
    case StorageSpecifierKind::CrdMemSize:
    case StorageSpecifierKind::PosMemSize:
    case StorageSpecifierKind::ValMemSize: {
      auto enc = op.getSpecifier().getType().getEncoding();
      StorageLayout layout(enc);
      std::optional<unsigned> lvl;
      if (op.getLevel())
        lvl = (*op.getLevel());
      unsigned idx =
          layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl);
      Value v = Base::onMemSize(rewriter, op, spec, idx);
      rewriter.replaceOp(op, v);
      return success();
    }
    }
    llvm_unreachable("unrecognized specifer kind");
  }
};

struct StorageSpecifierSetOpConverter
    : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
                                              SetStorageSpecifierOp> {
  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;

  static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
                         SpecifierStructBuilder &spec, Level lvl) {
    spec.setLvlSize(builder, op.getLoc(), lvl, op.getValue());
    return spec;
  }

  static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
                           SpecifierStructBuilder &spec, Dimension d) {
    spec.setDimOffset(builder, op.getLoc(), d, op.getValue());
    return spec;
  }

  static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
                           SpecifierStructBuilder &spec, Dimension d) {
    spec.setDimStride(builder, op.getLoc(), d, op.getValue());
    return spec;
  }

  static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
                         SpecifierStructBuilder &spec, FieldIndex fidx) {
    spec.setMemSize(builder, op.getLoc(), fidx, op.getValue());
    return spec;
  }
};

struct StorageSpecifierGetOpConverter
    : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
                                              GetStorageSpecifierOp> {
  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;

  static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
                         SpecifierStructBuilder &spec, Level lvl) {
    return spec.lvlSize(builder, op.getLoc(), lvl);
  }

  static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
                           const SpecifierStructBuilder &spec, Dimension d) {
    return spec.dimOffset(builder, op.getLoc(), d);
  }

  static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
                           const SpecifierStructBuilder &spec, Dimension d) {
    return spec.dimStride(builder, op.getLoc(), d);
  }

  static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
                         SpecifierStructBuilder &spec, FieldIndex fidx) {
    return spec.memSize(builder, op.getLoc(), fidx);
  }
};

struct StorageSpecifierInitOpConverter
    : public OpConversionPattern<StorageSpecifierInitOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
    rewriter.replaceOp(
        op, SpecifierStructBuilder::getInitValue(
                rewriter, op.getLoc(), llvmType, adaptor.getSource()));
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//

void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
               StorageSpecifierInitOpConverter>(converter,
                                                patterns.getContext());
}