//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<DeviceAsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
      >();
}

Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();
  // Handle 'device async token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
  return Type();
}

void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
}
//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset))) {
    return false;
  }
  return strides.back() == 1;
}

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = getSrc().getType().cast<MemRefType>();
  auto dstMemref = getDst().getType().cast<MemRefType>();
  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != getSrcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << getSrcIndices().size();
  if (size_t(dstMemref.getRank()) != getDstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got "
                         << getDstIndices().size();
  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_MmaSyncOp
//===----------------------------------------------------------------------===//

LogicalResult MmaSyncOp::verify() {

  // Fundamental tensor core mma.sync op
  // For F32 (TF32), F16, S8, and S4 data types fundamental tensor core
  // operation is of shape: 8-by-8-by-128b. F64 is an exception. The
  // verification for mma.sync covering various shapes and data types is based
  // on the fundamental tensor core operionation.
  constexpr int kThreads = 32; // 32 threads per warp
  int64_t shapeM = 8;
  int64_t shapeN = 8;
  int64_t shapeK; // set based on data type (128b for all data types except F64)

  // Number of elements A, B, and C per thread per fundamental tensor core tile
  int64_t numElementA;    // set based on data type (32b except F64)
  int64_t numElementB;    // set based on data type (32b except F64)
  int64_t numElementC{2}; // two accumulator elements per fundamental tile

  // nvgpu.mma.sync vector operands (per thread)
  auto aVector = getMatrixA().getType().cast<VectorType>();
  auto bVector = getMatrixB().getType().cast<VectorType>();
  auto cVector = getMatrixC().getType().cast<VectorType>();

  // vector shapes
  ArrayRef<int64_t> aShape = aVector.getShape();
  ArrayRef<int64_t> bShape = bVector.getShape();
  ArrayRef<int64_t> cShape = cVector.getShape();

  // vector element type
  Type aType = aVector.getElementType();

  // nvgpu.mma.sync shape (per 32 threads or per warp)
  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();

  if (aType.isF64()) {
    // exception to 8-by-8-128b fundamental tensor core tile size
    shapeK = 4;
    numElementA = 1;
    numElementB = 1;
  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
             aType.isInteger(8) || aType.isInteger(4)) {
    // 8-by-8-128b fundamental tensor core tile size
    int operandBitwidth = aType.getIntOrFloatBitWidth();
    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
    numElementA = 32 / operandBitwidth; // 32b wide operand A
    numElementB = 32 / operandBitwidth; // 32b wide operand B
  } else {
    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
                          "supported by nvgpu.mma.sync";
  }

  //
  // Basic verification
  //

  // verify warp-wide size for vector a
  if (aShape[0] * aShape[1] * kThreads != m * k)
    return emitOpError() << "expected " << m * k
                         << " warp-wide matrix A elements";

  // verify warp-wide size for vector b
  if (bShape[0] * bShape[1] * kThreads != k * n)
    return emitOpError() << "expected " << k * n
                         << " warp-wide matrix B elements";

  // verify warp-wide size for vector c
  if (cShape[0] * cShape[1] * kThreads != m * n)
    return emitOpError() << "expected " << m * n
                         << " warp-wide matrix C elements";

  //
  // Extended verification
  //

  // tiles of fundamental tensor core operations
  int64_t mTile = m / shapeM;
  int64_t nTile = n / shapeN;
  int64_t kTile = k / shapeK;

  // verify shape of aVector
  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
                         << " x " << numElementA << ")";

  // verify shape of bVector
  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
                         << " x " << numElementB << ")";

  // verify shape of cVector
  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
                         << " x " << numElementC << ")";

  return success();
}

//===----------------------------------------------------------------------===//
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {

  // ldmatrix reads data from source in shared memory
  auto srcMemref = getSrcMemref().getType().cast<MemRefType>();

  // ldmatrix writes data to result/destination in vector registers
  auto resVector = getRes().getType().cast<VectorType>();

  // vector register shape, element type, and bitwidth
  ArrayRef<int64_t> resShape = resVector.getShape();
  Type resType = resVector.getElementType();
  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();

  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
  int64_t numElementsPer32b = 32 / elementBitWidth;

  // number of 8-by-8 tiles
  int64_t numTiles = getNumTiles();

  // transpose elements in vector registers at 16b granularity when true
  bool isTranspose = getTranspose();

  // address space id for shared memory
  unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();

  //
  // verification
  //

  if (!(srcMemref.getMemorySpaceAsInt() == smemAddressSpace))
    return emitError()
           << "expected nvgpu.ldmatrix srcMemref must have memory space "
           << smemAddressSpace;
  if (elementBitWidth > 32)
    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
  if (isTranspose && !(elementBitWidth == 16))
    return emitError()
           << "nvgpu.ldmatrix transpose works only at 16b granularity";
  if (!(resShape[1] == numElementsPer32b))
    return emitError() << "expected vector register shape[1] = "
                       << numElementsPer32b;
  if (!(resShape[0] == numTiles))
    return emitError()
           << "expected vector register shape[0] and numTiles to match";

  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
