//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NV documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize))
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  // then
  //   bits[0:N] = sub-vector element offset
  //   bits[N:M] = vector index
  // clang-format on
  int64_t N =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t M = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
  int64_t mask = (1LL << (M - N)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
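  // Worked example (hypothetical `memref<?x64xf16>`, i.e. 16-bit elements and
  // an inner dimension of 64): N = log2(128/16) = 3, M = log2(64) = 6,
  // permuteEveryN = max(1, 128 / (64 * 16 / 8)) = 1, and
  // mask = (1 << (M - N)) - 1 = 0b111. The low three bits of the src (row)
  // index are therefore shifted left by N = 3 below and xor'ed into
  // bits [3:6) of the tgt (column) index.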
  if (permuteEveryN > 1) {
    int64_t shlBits = N - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, N);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

/// Replace `indices[tgtDim]` in place with the permuted index produced by
/// `permuteVectorOffset`.
static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

/// Return the indices that the given load/store-like `op` uses to access the
/// shared memory memref (the destination indices for `DeviceAsyncCopyOp`).
Operation::operand_range getIndices(Operation *op) {
  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
    return ldmatrixOp.getIndices();
  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
    return copyOp.getDstIndices();
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndices();
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndices();
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndices();
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndices();
  llvm_unreachable("unsupported op type");
}

/// Overwrite the indices of the given load/store-like `op` with `indices`.
void setIndices(Operation *op, ArrayRef<Value> indices) {
  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
    return ldmatrixOp.getIndicesMutable().assign(indices);
  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
    return copyOp.getDstIndicesMutable().assign(indices);
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndicesMutable().assign(indices);
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndicesMutable().assign(indices);
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndicesMutable().assign(indices);
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndicesMutable().assign(indices);
  llvm_unreachable("unsupported op type");
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    Optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
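  // (At least two indices are needed because the swizzle xors bits of the
  // second-to-last index into the last index.)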
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
  if (!memRefType || memRefType.getMemorySpaceAsInt() !=
                         gpu::GPUDialect::getWorkgroupAddressSpace())
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check if this is necessary given the assumption of 128b accesses:
  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
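  // The same permutation (same srcDim/tgtDim) is applied on the read side, so
  // every element is read back from the swizzled location it was written to.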
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (allocOp.getMemref()
              .getType()
              .cast<MemRefType>()
              .getMemorySpaceAsInt() !=
          gpu::GPUDialect::getWorkgroupAddressSpace())
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}
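// Example usage (a sketch; `pm` is assumed to be a client-owned PassManager or
// nested OpPassManager):
//   pm.addPass(nvgpu::createOptimizeSharedMemoryPass());
// After the pass runs, the trailing index of each shared-memory access is
// xor-swizzled, e.g. a store to `%shm[%i, %j]` becomes a store to
// `%shm[%i, %j ^ f(%i)]` for the bit permutation `f` built by
// `permuteVectorOffset` above.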