//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the bit offset of the element at `srcIdx` inside the
/// `targetBits`-wide value that contains it.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When the array is accessed as if its elements had `targetBits` bits,
/// multiple source elements are loaded at the same time. This function returns
/// the bit position at which the `srcIdx`-th element starts within that wider
/// value. For example, if `sourceBits` equals 8 and `targetBits` equals 32, the
/// x-th element is located at (x % 4) * 8, because there are four elements in
/// one i32 and each element has 8 bits.
///
/// All arithmetic is performed in the type of `srcIdx`, and the returned offset
/// has that type too. `targetBits` must be a multiple of `sourceBits`.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  // Materialize the constants in the index's own type rather than in an
  // integer of `targetBits`: the modulo and multiply below combine them with
  // `srcIdx`, and their operand types have to agree even when the index width
  // differs from `targetBits`.
  Type idxType = srcIdx.getType();
  IntegerAttr elemsPerWordAttr =
      builder.getIntegerAttr(idxType, targetBits / sourceBits);
  auto elemsPerWord =
      builder.create<spirv::ConstantOp>(loc, idxType, elemsPerWordAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(idxType, sourceBits);
  auto srcBitsValue =
      builder.create<spirv::ConstantOp>(loc, idxType, srcBitsAttr);
  // Position of the element inside the wide value, counted in elements...
  auto posInWord = builder.create<spirv::UModOp>(loc, srcIdx, elemsPerWord);
  // ...scaled to a position counted in bits.
  return builder.create<spirv::IMulOp>(loc, idxType, posInWord, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion, if a memref of an unsupported type is used,
/// loads/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and to extract the required bits. When accessing a
/// 1-D array (spv.array or spv.rt_array), the last index is divided by
/// `targetBits / sourceBits` so that it addresses the wider element containing
/// the requested bits. The extraction of the actual bits needed is handled
/// separately. Note that this only works for a 1-D tensor.
///
/// `op` is the original access chain; it is left untouched and a new access
/// chain with the rewritten last index is created and returned.
static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
                                          spirv::AccessChainOp op,
                                          int sourceBits, int targetBits,
                                          OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  // The divisor must have the same type as the index it divides, which is not
  // necessarily an integer of `targetBits`.
  Type idxType = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(idxType, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, idxType, attr);
  auto indices = llvm::to_vector<4>(op.indices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.component_ptr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
}

/// Masks `value` with `mask` and shifts the masked bits left by `offset`,
/// producing a `targetBits`-bit integer result.
static Value shiftValue(Location loc, Value value, Value offset, Value mask,
                        int targetBits, OpBuilder &builder) {
  // Keep only the bits selected by the mask before moving them into place.
  Value maskedBits = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
  Type resultType = builder.getIntegerType(targetBits);
  return builder.create<spirv::ShiftLeftLogicalOp>(loc, resultType, maskedBits,
                                                   offset);
}

/// Returns true if an allocation of type `t` can be lowered to SPIR-V.
///
/// Only statically shaped allocations in the memory space that maps to the
/// Workgroup storage class are supported, and the element type must be an
/// int/float or a vector of int/float.
static bool isAllocationSupported(MemRefType t) {
  if (!t.hasStaticShape())
    return false;
  if (SPIRVTypeConverter::getMemorySpaceForStorageClass(
          spirv::StorageClass::Workgroup) != t.getMemorySpaceAsInt())
    return false;

  // For vectors, the check applies to the scalar element type.
  Type scalarType = t.getElementType();
  if (auto vectorType = scalarType.dyn_cast<VectorType>())
    scalarType = vectorType.getElementType();
  return scalarType.isIntOrFloat();
}

/// Returns the scope to use for the atomic operations that emulate stores of
/// unsupported integer bitwidths, derived from the memory space of memref type
/// `t`. Returns None when the memory space has no storage class or the storage
/// class is neither StorageBuffer nor Workgroup.
static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
  Optional<spirv::StorageClass> maybeClass =
      SPIRVTypeConverter::getStorageClassForMemorySpace(
          t.getMemorySpaceAsInt());
  if (!maybeClass)
    return {};

  const spirv::StorageClass storageClass = *maybeClass;
  if (storageClass == spirv::StorageClass::StorageBuffer)
    return spirv::Scope::Device;
  if (storageClass == spirv::StorageClass::Workgroup)
    return spirv::Scope::Workgroup;
  return {};
}

/// Casts the given `srcInt` into a boolean value by comparing it against one.
/// An i1 input is returned unchanged.
static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
  Type intType = srcInt.getType();
  if (!intType.isInteger(1)) {
    Value unit = spirv::ConstantOp::getOne(intType, loc, builder);
    return builder.create<spirv::IEqualOp>(loc, srcInt, unit);
  }
  // Already a boolean; nothing to do.
  return srcInt;
}

/// Casts the given `srcBool` into an integer of `dstType`: one when the boolean
/// is set, zero otherwise. If `dstType` is i1, `srcBool` is returned as is.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
                            OpBuilder &builder) {
  assert(srcBool.getType().isInteger(1));
  if (!dstType.isInteger(1)) {
    Value whenFalse = spirv::ConstantOp::getZero(dstType, loc, builder);
    Value whenTrue = spirv::ConstantOp::getOne(dstType, loc, builder);
    return builder.create<spirv::SelectOp>(loc, dstType, srcBool, whenTrue,
                                           whenFalse);
  }
  return srcBool;
}

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {

/// Converts an allocation operation to SPIR-V. Currently only supports lowering
/// to Workgroup memory when the size is constant. Note that this pattern needs
/// to be applied in a pass that runs at least at spv.module scope since it will
/// add global variables into the spv.module.
class AllocOpPattern final : public OpConversionPattern<memref::AllocOp> {
public:
  using OpConversionPattern<memref::AllocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Removes a deallocation if it is for a supported allocation. Currently only
/// removes the deallocation if the memory space is workgroup memory.
class DeallocOpPattern final : public OpConversionPattern<memref::DeallocOp> {
public:
  using OpConversionPattern<memref::DeallocOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spv.Load on integers. Only memrefs with a signless
/// integer element type are handled. When the converted storage element is
/// wider than the source element, the requested bits are extracted from the
/// wider loaded value.
class IntLoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.load to spv.Load. Only memrefs whose element type is not a
/// signless integer are handled; this pattern fails to match on signless
/// integer element types.
class LoadOpPattern final : public OpConversionPattern<memref::LoadOp> {
public:
  using OpConversionPattern<memref::LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spv.Store on integers. Only memrefs with a signless
/// integer element type are handled. When the converted storage element is
/// wider than the stored value, the store is emulated with an atomic AND
/// followed by an atomic OR on the wider element.
class IntStoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts memref.store to spv.Store. Only memrefs whose element type is not a
/// signless integer are handled; this pattern fails to match on signless
/// integer element types.
class StoreOpPattern final : public OpConversionPattern<memref::StoreOp> {
public:
  using OpConversionPattern<memref::StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
// AllocOp
//===----------------------------------------------------------------------===//

/// Lowers a supported memref.alloc to a new spv.GlobalVariable plus an
/// address-of that replaces the allocation.
///
/// Emits an error and fails if the allocation type is not supported. Fails
/// without touching the IR if the memref type cannot be converted or if no
/// enclosing symbol table is found.
LogicalResult
AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  MemRefType allocType = operation.getType();
  if (!isAllocationSupported(allocType))
    return operation.emitError("unhandled allocation type");

  // Get the SPIR-V type for the allocation. The converter yields a null type
  // when it cannot convert the memref; bail out instead of creating a global
  // variable with a null type.
  Type spirvType = getTypeConverter()->convertType(allocType);
  if (!spirvType)
    return rewriter.notifyMatchFailure(operation, "type conversion failed");

  // Insert spv.GlobalVariable for this allocation.
  Operation *parent =
      SymbolTable::getNearestSymbolTable(operation->getParentOp());
  if (!parent)
    return failure();
  Location loc = operation.getLoc();
  spirv::GlobalVariableOp varOp;
  {
    // Restore the insertion point once the global variable has been created.
    OpBuilder::InsertionGuard guard(rewriter);
    Block &entryBlock = *parent->getRegion(0).begin();
    rewriter.setInsertionPointToStart(&entryBlock);
    // Name the variable after the number of global variables already present
    // in the symbol table's entry block.
    auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
    std::string varName =
        std::string("__workgroup_mem__") +
        std::to_string(std::distance(varOps.begin(), varOps.end()));
    varOp = rewriter.create<spirv::GlobalVariableOp>(loc, spirvType, varName,
                                                     /*initializer=*/nullptr);
  }

  // Get pointer to global variable at the current scope.
  rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
  return success();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

/// Erases a memref.dealloc whose memref type is a supported allocation type;
/// emits an error and fails otherwise.
LogicalResult
DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation,
                                  OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto memrefType = operation.memref().getType().cast<MemRefType>();
  if (isAllocationSupported(memrefType)) {
    // Nothing is emitted in place of the deallocation; just drop it.
    rewriter.eraseOp(operation);
    return success();
  }
  return operation.emitError("unhandled deallocation type");
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

/// Lowers a memref.load on a signless integer memref.
///
/// If the converted storage element has the same bitwidth as the source
/// element, a plain load is emitted (followed by an int-to-bool cast for
/// booleans). Otherwise the wider storage element is loaded and the requested
/// bits are extracted with shifts and a mask.
LogicalResult
IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  // Other element types are not handled by this pattern.
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
                           adaptor.indices(), loc, rewriter);

  if (!accessChainOp)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  // Booleans (i1) are treated as having the bitwidth configured in the type
  // converter options.
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;
  // The converted memref is a pointer to a struct whose first member is either
  // an array or a runtime array; its element type is the storage type.
  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast<spirv::PointerType>()
                         .getPointeeType();
  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly.
  if (srcBits == dstBits) {
    Value loadVal =
        rewriter.create<spirv::LoadOp>(loc, accessChainOp.getResult());
    if (isBool)
      loadVal = castIntNToBool(loc, loadVal, rewriter);
    rewriter.replaceOp(loadOp, loadVal);
    return success();
  }

  // Assume that getElementPtr() produces a linearized access. If it's a
  // scalar, the method still returns a linearized access. If the access is not
  // linearized, there will be offset issues.
  assert(accessChainOp.indices().size() == 2);
  // Load the wider storage element that contains the requested bits.
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
      loc, dstType, adjustedPtr,
      loadOp->getAttrOfType<spirv::MemoryAccessAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      loadOp->getAttrOfType<IntegerAttr>("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract corresponding bits.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally. The signedness
  // semantics are carried by the operation itself; we rely on other patterns to
  // handle the casting. The extension is done by shifting the value to the
  // most significant bits and arithmetically shifting it back.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
                                                      shiftValue);
  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                          shiftValue);

  if (isBool) {
    // Booleans are produced by comparing the extracted integer against one.
    // NOTE(review): `dstType` is reassigned here but not read afterwards.
    dstType = typeConverter.convertType(loadOp.getType());
    mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter);
    result = rewriter.create<spirv::IEqualOp>(loc, result, mask);
  } else if (result.getType().getIntOrFloatBitWidth() !=
             static_cast<unsigned>(dstBits)) {
    result = rewriter.create<spirv::SConvertOp>(loc, dstType, result);
  }
  rewriter.replaceOp(loadOp, result);

  // The original access chain was only used to derive the adjusted one.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

/// Lowers a memref.load whose element type is not a signless integer to a
/// load through the element pointer computed for the converted memref.
LogicalResult
LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  auto sourceType = loadOp.memref().getType().cast<MemRefType>();
  // Signless integer element types are rejected by this pattern.
  if (sourceType.getElementType().isSignlessInteger())
    return failure();

  auto &converter = *getTypeConverter<SPIRVTypeConverter>();
  auto elementPtr =
      spirv::getElementPtr(converter, sourceType, adaptor.memref(),
                           adaptor.indices(), loadOp.getLoc(), rewriter);
  if (elementPtr) {
    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, elementPtr);
    return success();
  }
  return failure();
}

/// Lowers a memref.store on a signless integer memref.
///
/// If the converted storage element has the same bitwidth as the stored value,
/// a plain store is emitted (preceded by a bool-to-int cast for booleans).
/// Otherwise the store is emulated on the wider storage element with an atomic
/// AND that clears the destination bits followed by an atomic OR that sets
/// them. Fails if no atomic scope is known for the memref's memory space.
LogicalResult
IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
  // Other element types are not handled by this pattern.
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(),
                           adaptor.indices(), loc, rewriter);

  if (!accessChainOp)
    return failure();

  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();

  // Booleans (i1) are treated as having the bitwidth configured in the type
  // converter options.
  bool isBool = srcBits == 1;
  if (isBool)
    srcBits = typeConverter.getOptions().boolNumBits;

  // The converted memref is a pointer to a struct whose first member is either
  // an array or a runtime array; its element type is the storage type.
  Type pointeeType = typeConverter.convertType(memrefType)
                         .cast<spirv::PointerType>()
                         .getPointeeType();
  Type structElemType = pointeeType.cast<spirv::StructType>().getElementType(0);
  Type dstType;
  if (auto arrayType = structElemType.dyn_cast<spirv::ArrayType>())
    dstType = arrayType.getElementType();
  else
    dstType = structElemType.cast<spirv::RuntimeArrayType>().getElementType();

  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // Same bitwidth: store the value directly.
  if (srcBits == dstBits) {
    Value storeVal = adaptor.value();
    if (isBool)
      storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
        storeOp, accessChainOp.getResult(), storeVal);
    return success();
  }

  // The emulation below relies on atomic operations, which need a scope. Check
  // that one is available before emitting any of the emulation ops.
  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return failure();

  // Since multiple threads may be processing at the same time, the emulation
  // is done with atomic operations. E.g., if the stored value is i8, rewrite
  // the StoreOp to
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) store the 32-bit value back
  // 4) load a 32-bit integer
  // 5) modify 8 bits in the loaded value
  // 6) store the 32-bit value back
  // Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6
  // are done by AtomicOr as another atomic step.
  assert(accessChainOp.indices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask =
      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);

  // Move the value to store into its position inside the wider element.
  Value storeVal = adaptor.value();
  if (isBool)
    storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter);
  storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  // The values returned by the atomic ops are not needed.
  rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The atomic ops above have already been inserted and perform the store, so
  // the original StoreOp can simply be erased. Note that rewriter.replaceOp()
  // doesn't work here because it requires the number of results to match.
  rewriter.eraseOp(storeOp);

  // The original access chain was only used to derive the adjusted one.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}

/// Lowers a memref.store whose element type is not a signless integer to a
/// store through the element pointer computed for the converted memref.
LogicalResult
StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto targetType = storeOp.memref().getType().cast<MemRefType>();
  // Signless integer element types are rejected by this pattern.
  if (targetType.getElementType().isSignlessInteger())
    return failure();

  auto &converter = *getTypeConverter<SPIRVTypeConverter>();
  auto elementPtr =
      spirv::getElementPtr(converter, targetType, adaptor.memref(),
                           adaptor.indices(), storeOp.getLoc(), rewriter);
  if (elementPtr) {
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, elementPtr,
                                                adaptor.value());
    return success();
  }
  return failure();
}

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

namespace mlir {
/// Adds all MemRef-to-SPIR-V conversion patterns defined in this file to
/// `patterns`, each constructed with `typeConverter`.
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                   RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  // Registration order is kept: allocation patterns, then the integer
  // load/store patterns, then the generic load/store patterns.
  patterns.add<AllocOpPattern, DeallocOpPattern>(typeConverter, context);
  patterns.add<IntLoadOpPattern, IntStoreOpPattern>(typeConverter, context);
  patterns.add<LoadOpPattern, StoreOpPattern>(typeConverter, context);
}
} // namespace mlir
