//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements transforms to optimize accesses to shared memory.
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Passes.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
using namespace mlir::nvgpu;

/// The size of a shared memory line according to NVIDIA documentation.
constexpr int64_t kSharedMemoryLineSizeBytes = 128;
/// We optimize for 128-bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 128;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
///               floordiv(tgtIdxVal, vectorSize))
///            + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
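///
/// Illustrative sketch (not derived from a specific caller below): with f16
/// elements, `vectorSize` is 8 (one 128-bit access holds eight elements). For
/// `permuteEveryN == 1`, a source row index of 3 XORs the value 3 into the
/// vector-index bits of the target index, so the same logical vector column in
/// consecutive rows maps to different physical offsets (and thus different
/// shared memory banks).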
static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                 ArrayRef<Value> indices, MemRefType memrefTy,
                                 int64_t srcDim, int64_t tgtDim) {
  // Adjust the src index to change how often the permutation changes
  // if necessary.
  Value src = indices[srcDim];

  // We only want to permute every N iterations of the target dim where N is
  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
  const int64_t permuteEveryN = std::max<int64_t>(
      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                        memrefTy.getElementTypeBitWidth()) /
                                       8));
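  // For example, for a `memref<?x32xf16>` the target row spans
  // 32 * 16 / 8 = 64 bytes, so permuteEveryN = 128 / 64 = 2 and the swizzle
  // pattern only changes every second value of the source index.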

  // clang-format off
  // Index bit representation (b0 = least significant bit) for dim(1)
  // of a `memref<?x?xDT>` is as follows:
  //   N := log2(128/elementSizeBits)
  //   M := log2(dimSize(1))
  // then
  //   bits[0:N] = sub-vector element offset
  //   bits[N:M] = vector index
  // clang-format on
  int64_t N =
      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
  int64_t M = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
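  // E.g., for 16-bit (f16) elements, N = log2(128 / 16) = 3, and for a target
  // dimension of 64 elements, M = log2(64) = 6, so bits [3:6) hold the index
  // of the 128-bit vector within the row.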

  // Capture bits [0:M-N) of src by first creating an (M-N)-bit mask.
  int64_t mask = (1LL << (M - N)) - 1;
  if (permuteEveryN > 1)
    mask = mask << llvm::Log2_64(permuteEveryN);
  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
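  // Continuing the f16/64-element example: mask = (1 << (6 - 3)) - 1 = 0b111,
  // so with permuteEveryN == 1 `srcBits` is simply `src % 8`.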

  // Use the src bits to permute the target bits b[N:M] containing the
  // vector offset.
  if (permuteEveryN > 1) {
    int64_t shlBits = N - llvm::Log2_64(permuteEveryN);
    if (shlBits > 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
    } else if (shlBits < 0) {
      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
    }
  } else {
    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, N);
    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
  }
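  // At this point `srcBits` occupies the vector-index bits [N:M) of the index,
  // so XORing it into the target index permutes the vector offset while
  // leaving the sub-vector element offset (bits [0:N)) untouched.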

  Value permutedVectorIdx =
      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
  return permutedVectorIdx;
}

static void transformIndices(OpBuilder &builder, Location loc,
                             SmallVector<Value, 4> &indices,
                             MemRefType memrefTy, int64_t srcDim,
                             int64_t tgtDim) {
  indices[tgtDim] =
      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
}

Operation::operand_range getIndices(Operation *op) {
  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
    return ldmatrixOp.getIndices();
  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
    return copyOp.getDstIndices();
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndices();
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndices();
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndices();
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndices();
  llvm_unreachable("unsupported op type");
}

void setIndices(Operation *op, ArrayRef<Value> indices) {
  if (auto ldmatrixOp = dyn_cast<LdMatrixOp>(op))
    return ldmatrixOp.getIndicesMutable().assign(indices);
  if (auto copyOp = dyn_cast<DeviceAsyncCopyOp>(op))
    return copyOp.getDstIndicesMutable().assign(indices);
  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
    return loadOp.getIndicesMutable().assign(indices);
  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
    return storeOp.getIndicesMutable().assign(indices);
  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
    return vectorReadOp.getIndicesMutable().assign(indices);
  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
    return vectorStoreOp.getIndicesMutable().assign(indices);
  llvm_unreachable("unsupported op type");
}

/// Return all operations within `parentOp` that read from or write to
/// `shmMemRef`.
static LogicalResult
getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                      SmallVector<Operation *, 16> &readOps,
                      SmallVector<Operation *, 16> &writeOps) {
  parentOp->walk([&](Operation *op) {
    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface)
      return;
    Optional<MemoryEffects::EffectInstance> effect =
        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
    if (effect) {
      readOps.push_back(op);
      return;
    }
    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
    if (effect)
      writeOps.push_back(op);
  });

  // Restrict to a supported set of ops. We also require at least 2D access,
  // although this could be relaxed.
  if (llvm::any_of(readOps, [](Operation *op) {
        return !isa<memref::LoadOp, vector::LoadOp, nvgpu::LdMatrixOp>(op) ||
               getIndices(op).size() < 2;
      }))
    return failure();
  if (llvm::any_of(writeOps, [](Operation *op) {
        return !isa<memref::StoreOp, vector::StoreOp, nvgpu::DeviceAsyncCopyOp>(
                   op) ||
               getIndices(op).size() < 2;
      }))
    return failure();

  return success();
}

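/// Rewrite all reads and writes of `memrefValue` (a workgroup/shared memory
/// allocation) nested under `parentOp` so that the innermost index is
/// XOR-permuted with bits of the second-innermost index. A sketch of the
/// effect, for a hypothetical `memref<64x16xf16, 3>` buffer `%shm`: a
/// `memref.store %v, %shm[%i, %j]` becomes a store to `[%i, %j']` where `%j'`
/// mixes bits of `%i` into `%j`; all matching reads are rewritten with the
/// same permutation, so semantics are preserved while accesses are spread
/// across shared memory banks.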
mlir::LogicalResult
mlir::nvgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                Value memrefValue) {
  auto memRefType = memrefValue.getType().dyn_cast<MemRefType>();
  if (!memRefType || memRefType.getMemorySpaceAsInt() !=
                         gpu::GPUDialect::getWorkgroupAddressSpace())
    return failure();

  // Abort if the given value has any sub-views; we do not do any alias
  // analysis.
  bool hasSubView = false;
  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
  if (hasSubView)
    return failure();

  // Check whether the transformation is necessary, given the assumption of
  // 128-bit accesses: skip the memref if dim[rank-1] is small enough that
  // eight or more rows fit in a single 128-byte line.
  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
  const int64_t rowsPerLine =
      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
      rowSize;
  const int64_t threadGroupSize =
      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
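  // With the default 128-bit (16-byte) accesses, threadGroupSize is
  // 1 << (7 - log2(16)) = 8, i.e. eight such accesses cover one 128-byte line.
  // `rowsPerLine` counts how many rows fit in that line; when it reaches the
  // thread group size, the swizzle below is skipped.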
  if (rowsPerLine >= threadGroupSize)
    return failure();

  // Get sets of operations within the function that read/write to shared
  // memory.
  SmallVector<Operation *, 16> shmReadOps;
  SmallVector<Operation *, 16> shmWriteOps;
  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
                                   shmWriteOps)))
    return failure();

  if (shmReadOps.empty() || shmWriteOps.empty())
    return failure();

  OpBuilder builder(parentOp->getContext());

  int64_t tgtDim = memRefType.getRank() - 1;
  int64_t srcDim = memRefType.getRank() - 2;

  // Transform indices for the ops writing to shared memory.
  while (!shmWriteOps.empty()) {
    Operation *shmWriteOp = shmWriteOps.back();
    shmWriteOps.pop_back();
    builder.setInsertionPoint(shmWriteOp);

    auto indices = getIndices(shmWriteOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmWriteOp, transformedIndices);
  }

  // Transform indices for the ops reading from shared memory.
  while (!shmReadOps.empty()) {
    Operation *shmReadOp = shmReadOps.back();
    shmReadOps.pop_back();
    builder.setInsertionPoint(shmReadOp);

    auto indices = getIndices(shmReadOp);
    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                     memRefType, srcDim, tgtDim);
    setIndices(shmReadOp, transformedIndices);
  }

  return success();
}

namespace {
class OptimizeSharedMemoryPass
    : public OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
  OptimizeSharedMemoryPass() = default;

  void runOnOperation() override {
    Operation *op = getOperation();
    SmallVector<memref::AllocOp> shmAllocOps;
    op->walk([&](memref::AllocOp allocOp) {
      if (allocOp.getMemref()
              .getType()
              .cast<MemRefType>()
              .getMemorySpaceAsInt() !=
          gpu::GPUDialect::getWorkgroupAddressSpace())
        return;
      shmAllocOps.push_back(allocOp);
    });
    for (auto allocOp : shmAllocOps) {
      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
                                                    allocOp.getMemref())))
        return;
    }
  }
};
} // namespace

std::unique_ptr<Pass> mlir::nvgpu::createOptimizeSharedMemoryPass() {
  return std::make_unique<OptimizeSharedMemoryPass>();
}