//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
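//
// For example (schematically), an op such as
//   %sum = gpu.all_reduce add %value {} : (f32) -> (f32)
// inside a gpu.func is rewritten into gpu.shuffle-based subgroup reductions,
// stores to and loads from a workgroup-memory buffer, and gpu.barrier
// synchronization, as sketched in the comment on GpuAllReduceRewriter::rewrite().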
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
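  /// Callback that emits the ops accumulating two values and returns the
  /// accumulated result.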
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cf.cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     cf.br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     cf.br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
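    // The linearization below computes
    //   invocationIdx = tidX + dimX * (tidY + dimY * tidZ)
    //   workgroupSize = dimX * dimY * dimZ
    // with all intermediate values in i32.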
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
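    // kSubgroupSize is a power of two, so AND-ing with (kSubgroupSize - 1)
    // is equivalent to invocationIdx % kSubgroupSize.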
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create factory for op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
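    // numSubgroups = ceilDiv(workgroupSize, kSubgroupSize): adding the
    // subgroup mask before the signed division rounds the quotient up.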
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup buffer of kSubgroupSize elements of valueType to
  /// funcOp's workgroup attributions and returns it.
  Value createWorkgroupBuffer() {
    // TODO: Pick a proper location for the attribution.
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield ops are expected to
  /// return the accumulated value of the same type.
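  /// The cloned blocks are spliced into the current region; each gpu.yield is
  /// rewritten into a cf.br to a continuation block whose block argument
  /// carries the accumulated value.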
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the split blocks.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<cf::BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<cf::BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType(), lhs.getLoc());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    switch (opName) {
    case gpu::AllReduceOperation::ADD:
      return isFloatingPoint ? getFactory<arith::AddFOp>()
                             : getFactory<arith::AddIOp>();
    case gpu::AllReduceOperation::MUL:
      return isFloatingPoint ? getFactory<arith::MulFOp>()
                             : getFactory<arith::MulIOp>();
    case gpu::AllReduceOperation::AND:
      return getFactory<arith::AndIOp>();
    case gpu::AllReduceOperation::OR:
      return getFactory<arith::OrIOp>();
    case gpu::AllReduceOperation::XOR:
      return getFactory<arith::XOrIOp>();
    case gpu::AllReduceOperation::MAX:
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::UGT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ugt>();
    case gpu::AllReduceOperation::MIN:
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::ULT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ult>();
    }
    llvm_unreachable("unknown GPU AllReduceOperation");
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator factory for comparison ops such as min and max.
  /// T is the type of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cf.cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     cf.br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     cf.br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
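  /// Both factories must return the same number of values, which become the
  /// block arguments of ^continue.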
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<cf::CondBranchOp>(condition, thenBlock,
                             /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                             /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<cf::BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<cf::BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType(), operand.getLoc());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
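  /// The reduction is a butterfly pattern: log2(kSubgroupSize) gpu.shuffle xor
  /// steps with offsets 1, 2, 4, ..., kSubgroupSize / 2, accumulating after
  /// each step.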
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

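  /// Fixed subgroup (warp) size assumed by this lowering.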
  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceConversion>(patterns.getContext());
}