//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//
1251b925dfSChristopher Bate #include "PassDetail.h"
1351b925dfSChristopher Bate #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
1451b925dfSChristopher Bate #include "mlir/Dialect/GPU/IR/GPUDialect.h"
1551b925dfSChristopher Bate #include "mlir/Dialect/MemRef/IR/MemRef.h"
1651b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
1751b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/Passes.h"
1851b925dfSChristopher Bate #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
1951b925dfSChristopher Bate #include "mlir/Dialect/Vector/IR/VectorOps.h"
2051b925dfSChristopher Bate #include "mlir/Interfaces/SideEffectInterfaces.h"
2151b925dfSChristopher Bate #include "mlir/Support/LogicalResult.h"
2251b925dfSChristopher Bate #include "llvm/ADT/STLExtras.h"
2351b925dfSChristopher Bate #include "llvm/Support/MathExtras.h"
2451b925dfSChristopher Bate 
2551b925dfSChristopher Bate using namespace mlir;
2651b925dfSChristopher Bate using namespace mlir::nvgpu;
2751b925dfSChristopher Bate 
/// Width, in bytes, of one shared memory line according to NV documentation.
static constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// Access width, in bits, that the index permutation is tuned for. It is
/// fixed at 128 for now; it could be turned into an argument in the future.
static constexpr int64_t kDefaultVectorSizeBits = 128;
3351b925dfSChristopher Bate 
3451b925dfSChristopher Bate /// Uses `srcIndexValue` to permute `tgtIndexValue` via
3551b925dfSChristopher Bate /// `result = xor(floordiv(srcIdxVal,permuteEveryN),
3651b925dfSChristopher Bate ///               floordiv(tgtIdxVal,vectorSize)))
3751b925dfSChristopher Bate ///            + tgtIdxVal % vectorSize`
3851b925dfSChristopher Bate /// This is done using an optimized sequence of `arith` operations.
permuteVectorOffset(OpBuilder & b,Location loc,ArrayRef<Value> indices,MemRefType memrefTy,int64_t srcDim,int64_t tgtDim)3951b925dfSChristopher Bate static Value permuteVectorOffset(OpBuilder &b, Location loc,
4051b925dfSChristopher Bate                                  ArrayRef<Value> indices, MemRefType memrefTy,
4151b925dfSChristopher Bate                                  int64_t srcDim, int64_t tgtDim) {
4251b925dfSChristopher Bate   // Adjust the src index to change how often the permutation changes
4351b925dfSChristopher Bate   // if necessary.
4451b925dfSChristopher Bate   Value src = indices[srcDim];
4551b925dfSChristopher Bate 
4651b925dfSChristopher Bate   // We only want to permute every N iterations of the target dim where N is
4751b925dfSChristopher Bate   // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
4851b925dfSChristopher Bate   const int64_t permuteEveryN = std::max<int64_t>(
4951b925dfSChristopher Bate       1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
5051b925dfSChristopher Bate                                         memrefTy.getElementTypeBitWidth()) /
5151b925dfSChristopher Bate                                        8));
5251b925dfSChristopher Bate 
5351b925dfSChristopher Bate   // clang-format off
5451b925dfSChristopher Bate   // Index bit representation (b0 = least significant bit) for dim(1)
5551b925dfSChristopher Bate   // of a `memref<?x?xDT>` is as follows:
5651b925dfSChristopher Bate   // N := log2(128/elementSizeBits)
5751b925dfSChristopher Bate   // M := log2(dimSize(1))
5851b925dfSChristopher Bate   // then
5951b925dfSChristopher Bate   // bits[0:N] = sub-vector element offset
6051b925dfSChristopher Bate   // bits[N:M] = vector index
6151b925dfSChristopher Bate   // clang-format on
6251b925dfSChristopher Bate   int64_t N =
6351b925dfSChristopher Bate       llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
6451b925dfSChristopher Bate   int64_t M = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
6551b925dfSChristopher Bate 
6651b925dfSChristopher Bate   // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
67829c84ecSChristopher Bate   int64_t mask = (1LL << (M - N)) - 1;
6851b925dfSChristopher Bate   if (permuteEveryN > 1)
6951b925dfSChristopher Bate     mask = mask << llvm::Log2_64(permuteEveryN);
7051b925dfSChristopher Bate   Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
7151b925dfSChristopher Bate   srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
7251b925dfSChristopher Bate 
7351b925dfSChristopher Bate   // Use the src bits to permute the target bits b[N:M] containing the
7451b925dfSChristopher Bate   // vector offset.
7551b925dfSChristopher Bate   if (permuteEveryN > 1) {
7651b925dfSChristopher Bate     int64_t shlBits = N - llvm::Log2_64(permuteEveryN);
7751b925dfSChristopher Bate     if (shlBits > 0) {
7851b925dfSChristopher Bate       Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
7951b925dfSChristopher Bate       srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
8051b925dfSChristopher Bate     } else if (shlBits < 0) {
8151b925dfSChristopher Bate       Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
8251b925dfSChristopher Bate       srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
8351b925dfSChristopher Bate     }
8451b925dfSChristopher Bate   } else {
8551b925dfSChristopher Bate     Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, N);
8651b925dfSChristopher Bate     srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
8751b925dfSChristopher Bate   }
8851b925dfSChristopher Bate 
8951b925dfSChristopher Bate   Value permutedVectorIdx =
9051b925dfSChristopher Bate       b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
9151b925dfSChristopher Bate   return permutedVectorIdx;
9251b925dfSChristopher Bate }
9351b925dfSChristopher Bate 
transformIndices(OpBuilder & builder,Location loc,SmallVector<Value,4> & indices,MemRefType memrefTy,int64_t srcDim,int64_t tgtDim)9451b925dfSChristopher Bate static void transformIndices(OpBuilder &builder, Location loc,
9551b925dfSChristopher Bate                              SmallVector<Value, 4> &indices,
9651b925dfSChristopher Bate                              MemRefType memrefTy, int64_t srcDim,
9751b925dfSChristopher Bate                              int64_t tgtDim) {
9851b925dfSChristopher Bate   indices[tgtDim] =
9951b925dfSChristopher Bate       permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
10051b925dfSChristopher Bate }
10151b925dfSChristopher Bate 
getIndices(Operation * op)10251b925dfSChristopher Bate Operation::operand_range getIndices(Operation *op) {
10351b925dfSChristopher Bate   if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
1048df54a6aSJacques Pienaar     return ldmatrixOp.getIndices();
10551b925dfSChristopher Bate   if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
1068df54a6aSJacques Pienaar     return copyOp.getDstIndices();
10751b925dfSChristopher Bate   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
108*136d746eSJacques Pienaar     return loadOp.getIndices();
10951b925dfSChristopher Bate   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
110*136d746eSJacques Pienaar     return storeOp.getIndices();
11151b925dfSChristopher Bate   if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
11251b925dfSChristopher Bate     return vectorReadOp.getIndices();
11351b925dfSChristopher Bate   if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
11451b925dfSChristopher Bate     return vectorStoreOp.getIndices();
11551b925dfSChristopher Bate   llvm_unreachable("unsupported op type");
11651b925dfSChristopher Bate }
11751b925dfSChristopher Bate 
setIndices(Operation * op,ArrayRef<Value> indices)11851b925dfSChristopher Bate void setIndices(Operation *op, ArrayRef<Value> indices) {
11951b925dfSChristopher Bate   if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
1208df54a6aSJacques Pienaar     return ldmatrixOp.getIndicesMutable().assign(indices);
12151b925dfSChristopher Bate   if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
1228df54a6aSJacques Pienaar     return copyOp.getDstIndicesMutable().assign(indices);
12351b925dfSChristopher Bate   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
124*136d746eSJacques Pienaar     return loadOp.getIndicesMutable().assign(indices);
12551b925dfSChristopher Bate   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
126*136d746eSJacques Pienaar     return storeOp.getIndicesMutable().assign(indices);
12751b925dfSChristopher Bate   if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
12851b925dfSChristopher Bate     return vectorReadOp.getIndicesMutable().assign(indices);
12951b925dfSChristopher Bate   if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
13051b925dfSChristopher Bate     return vectorStoreOp.getIndicesMutable().assign(indices);
13151b925dfSChristopher Bate   llvm_unreachable("unsupported op type");
13251b925dfSChristopher Bate }
13351b925dfSChristopher Bate 
13451b925dfSChristopher Bate /// Return all operations within `parentOp` that read from or write to
13551b925dfSChristopher Bate /// `shmMemRef`.
13651b925dfSChristopher Bate static LogicalResult
getShmReadAndWriteOps(Operation * parentOp,Value shmMemRef,SmallVector<Operation *,16> & readOps,SmallVector<Operation *,16> & writeOps)13751b925dfSChristopher Bate getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
13851b925dfSChristopher Bate                       SmallVector<Operation *, 16> &readOps,
13951b925dfSChristopher Bate                       SmallVector<Operation *, 16> &writeOps) {
14051b925dfSChristopher Bate   parentOp->walk([&](Operation *op) {
14151b925dfSChristopher Bate     MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
14251b925dfSChristopher Bate     if (!iface)
14351b925dfSChristopher Bate       return;
14451b925dfSChristopher Bate     Optional<MemoryEffects::EffectInstance> effect =
14551b925dfSChristopher Bate         iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
14651b925dfSChristopher Bate     if (effect) {
14751b925dfSChristopher Bate       readOps.push_back(op);
14851b925dfSChristopher Bate       return;
14951b925dfSChristopher Bate     }
15051b925dfSChristopher Bate     effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
15151b925dfSChristopher Bate     if (effect)
15251b925dfSChristopher Bate       writeOps.push_back(op);
15351b925dfSChristopher Bate   });
15451b925dfSChristopher Bate 
15551b925dfSChristopher Bate   // Restrict to a supported set of ops. We also require at least 2D access,
15651b925dfSChristopher Bate   // although this could be relaxed.
15751b925dfSChristopher Bate   if (llvm::any_of(readOps, [](Operation *op) {
15851b925dfSChristopher Bate         return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
15951b925dfSChristopher Bate                getIndices(op).size() < 2;
16051b925dfSChristopher Bate       }))
16151b925dfSChristopher Bate     return failure();
16251b925dfSChristopher Bate   if (llvm::any_of(writeOps, [](Operation *op) {
16351b925dfSChristopher Bate         return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
16451b925dfSChristopher Bate                    op) ||
16551b925dfSChristopher Bate                getIndices(op).size() < 2;
16651b925dfSChristopher Bate       }))
16751b925dfSChristopher Bate     return failure();
16851b925dfSChristopher Bate 
16951b925dfSChristopher Bate   return success();
17051b925dfSChristopher Bate }
17151b925dfSChristopher Bate 
17251b925dfSChristopher Bate mlir::LogicalResult
optimizeSharedMemoryReadsAndWrites(Operation * parentOp,Value memrefValue)17351b925dfSChristopher Bate mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
17451b925dfSChristopher Bate                                                 Value memrefValue) {
17551b925dfSChristopher Bate   auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
17651b925dfSChristopher Bate   if (!memRefType || memRefType.getMemorySpaceAsInt() !=
17751b925dfSChristopher Bate                          gpu::GPUDialect::getWorkgroupAddressSpace())
17851b925dfSChristopher Bate     return failure();
17951b925dfSChristopher Bate 
18051b925dfSChristopher Bate   // Abort if the given value has any sub-views; we do not do any alias
18151b925dfSChristopher Bate   // analysis.
18251b925dfSChristopher Bate   bool hasSubView = false;
18351b925dfSChristopher Bate   parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
18451b925dfSChristopher Bate   if (hasSubView)
18551b925dfSChristopher Bate     return failure();
18651b925dfSChristopher Bate 
18751b925dfSChristopher Bate   // Check if this is necessary given the assumption of 128b accesses:
18851b925dfSChristopher Bate   // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
18951b925dfSChristopher Bate   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
19051b925dfSChristopher Bate   const int64_t rowsPerLine =
19151b925dfSChristopher Bate       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
19251b925dfSChristopher Bate       rowSize;
19351b925dfSChristopher Bate   const int64_t threadGroupSize =
194829c84ecSChristopher Bate       1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
19551b925dfSChristopher Bate   if (rowsPerLine >= threadGroupSize)
19651b925dfSChristopher Bate     return failure();
19751b925dfSChristopher Bate 
19851b925dfSChristopher Bate   // Get sets of operations within the function that read/write to shared
19951b925dfSChristopher Bate   // memory.
20051b925dfSChristopher Bate   SmallVector<Operation *, 16> shmReadOps;
20151b925dfSChristopher Bate   SmallVector<Operation *, 16> shmWriteOps;
20251b925dfSChristopher Bate   if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
20351b925dfSChristopher Bate                                    shmWriteOps)))
20451b925dfSChristopher Bate     return failure();
20551b925dfSChristopher Bate 
20651b925dfSChristopher Bate   if (shmReadOps.empty() || shmWriteOps.empty())
20751b925dfSChristopher Bate     return failure();
20851b925dfSChristopher Bate 
20951b925dfSChristopher Bate   OpBuilder builder(parentOp->getContext());
21051b925dfSChristopher Bate 
21151b925dfSChristopher Bate   int64_t tgtDim = memRefType.getRank() - 1;
21251b925dfSChristopher Bate   int64_t srcDim = memRefType.getRank() - 2;
21351b925dfSChristopher Bate 
21451b925dfSChristopher Bate   // Transform indices for the ops writing to shared memory.
21551b925dfSChristopher Bate   while (!shmWriteOps.empty()) {
21651b925dfSChristopher Bate     Operation *shmWriteOp = shmWriteOps.back();
21751b925dfSChristopher Bate     shmWriteOps.pop_back();
21851b925dfSChristopher Bate     builder.setInsertionPoint(shmWriteOp);
21951b925dfSChristopher Bate 
22051b925dfSChristopher Bate     auto indices = getIndices(shmWriteOp);
22151b925dfSChristopher Bate     SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
22251b925dfSChristopher Bate     transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
22351b925dfSChristopher Bate                      memRefType, srcDim, tgtDim);
22451b925dfSChristopher Bate     setIndices(shmWriteOp, transformedIndices);
22551b925dfSChristopher Bate   }
22651b925dfSChristopher Bate 
22751b925dfSChristopher Bate   // Transform indices for the ops reading from shared memory.
22851b925dfSChristopher Bate   while (!shmReadOps.empty()) {
22951b925dfSChristopher Bate     Operation *shmReadOp = shmReadOps.back();
23051b925dfSChristopher Bate     shmReadOps.pop_back();
23151b925dfSChristopher Bate     builder.setInsertionPoint(shmReadOp);
23251b925dfSChristopher Bate 
23351b925dfSChristopher Bate     auto indices = getIndices(shmReadOp);
23451b925dfSChristopher Bate     SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
23551b925dfSChristopher Bate     transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
23651b925dfSChristopher Bate                      memRefType, srcDim, tgtDim);
23751b925dfSChristopher Bate     setIndices(shmReadOp, transformedIndices);
23851b925dfSChristopher Bate   }
23951b925dfSChristopher Bate 
24051b925dfSChristopher Bate   return success();
24151b925dfSChristopher Bate }
24251b925dfSChristopher Bate 
24351b925dfSChristopher Bate namespace {
24451b925dfSChristopher Bate class OptimizeSharedMemoryPass
24551b925dfSChristopher Bate     : public OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
24651b925dfSChristopher Bate public:
24751b925dfSChristopher Bate   OptimizeSharedMemoryPass() = default;
24851b925dfSChristopher Bate 
runOnOperation()24951b925dfSChristopher Bate   void runOnOperation() override {
25051b925dfSChristopher Bate     Operation *op = getOperation();
25151b925dfSChristopher Bate     SmallVector<memref::AllocOp> shmAllocOps;
25251b925dfSChristopher Bate     op->walk([&](memref::AllocOp allocOp) {
253*136d746eSJacques Pienaar       if (allocOp.getMemref()
254*136d746eSJacques Pienaar               .getType()
255*136d746eSJacques Pienaar               .cast<MemRefType>()
256*136d746eSJacques Pienaar               .getMemorySpaceAsInt() !=
25751b925dfSChristopher Bate           gpu::GPUDialect::getWorkgroupAddressSpace())
25851b925dfSChristopher Bate         return;
25951b925dfSChristopher Bate       shmAllocOps.push_back(allocOp);
26051b925dfSChristopher Bate     });
26151b925dfSChristopher Bate     for (auto allocOp : shmAllocOps) {
26251b925dfSChristopher Bate       if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
263*136d746eSJacques Pienaar                                                     allocOp.getMemref())))
26451b925dfSChristopher Bate         return;
26551b925dfSChristopher Bate     }
26651b925dfSChristopher Bate   }
26751b925dfSChristopher Bate };
26851b925dfSChristopher Bate } // namespace
26951b925dfSChristopher Bate 
createOptimizeSharedMemoryPass()27051b925dfSChristopher Bate std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
27151b925dfSChristopher Bate   return std::make_unique<OptimizeSharedMemoryPass>();
27251b925dfSChristopher Bate }
273