//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering of the gpu.all_reduce op into a sequence
// of simpler operations: subgroup shuffles, barriers, and workgroup-memory
// loads/stores.
//
//===----------------------------------------------------------------------===//
138b2eb7c4SChristian Sigg 
14a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15ace01605SRiver Riddle #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16*d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
17*d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h"
18e2310704SJulian Gross #include "mlir/Dialect/MemRef/IR/MemRef.h"
198b2eb7c4SChristian Sigg #include "mlir/IR/BlockAndValueMapping.h"
208b2eb7c4SChristian Sigg #include "mlir/IR/Builders.h"
218b2eb7c4SChristian Sigg #include "mlir/IR/PatternMatch.h"
228b2eb7c4SChristian Sigg #include "mlir/Pass/Pass.h"
238b2eb7c4SChristian Sigg 
248b2eb7c4SChristian Sigg using namespace mlir;
258b2eb7c4SChristian Sigg 
268b2eb7c4SChristian Sigg namespace {
278b2eb7c4SChristian Sigg 
288b2eb7c4SChristian Sigg struct GpuAllReduceRewriter {
298b2eb7c4SChristian Sigg   using AccumulatorFactory = std::function<Value(Value, Value)>;
308b2eb7c4SChristian Sigg 
GpuAllReduceRewriter__anon5d3f75af0111::GpuAllReduceRewriter3102b6fb21SMehdi Amini   GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
3202b6fb21SMehdi Amini                        PatternRewriter &rewriter)
3302b6fb21SMehdi Amini       : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
348b2eb7c4SChristian Sigg         loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
358b2eb7c4SChristian Sigg         indexType(IndexType::get(reduceOp.getContext())),
361b97cdf8SRiver Riddle         int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}
378b2eb7c4SChristian Sigg 
388b2eb7c4SChristian Sigg   /// Creates an all_reduce across the workgroup.
398b2eb7c4SChristian Sigg   ///
408b2eb7c4SChristian Sigg   /// First reduce the elements within a subgroup. The first invocation of each
418b2eb7c4SChristian Sigg   /// subgroup writes the intermediate result to workgroup memory. After
428b2eb7c4SChristian Sigg   /// synchronizing the workgroup, the first subgroup reduces the values from
438b2eb7c4SChristian Sigg   /// workgroup memory. The result is broadcasted to all invocations through
448b2eb7c4SChristian Sigg   /// workgroup memory.
458b2eb7c4SChristian Sigg   ///
468b2eb7c4SChristian Sigg   ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
47ace01605SRiver Riddle   ///     cf.cond_br %is_first_lane, ^then1, ^continue1
488b2eb7c4SChristian Sigg   ///   ^then1:
498b2eb7c4SChristian Sigg   ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
50ace01605SRiver Riddle   ///     cf.br ^continue1
518b2eb7c4SChristian Sigg   ///   ^continue1:
528b2eb7c4SChristian Sigg   ///     gpu.barrier
53a54f4eaeSMogball   ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
54ace01605SRiver Riddle   ///     cf.cond_br %is_valid_subgroup, ^then2, ^continue2
558b2eb7c4SChristian Sigg   ///   ^then2:
568b2eb7c4SChristian Sigg   ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
578b2eb7c4SChristian Sigg   ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
588b2eb7c4SChristian Sigg   ///     store %all_reduce, %workgroup_buffer[%zero]
598b2eb7c4SChristian Sigg   ///     llvm.br ^continue2
608b2eb7c4SChristian Sigg   ///   ^continue2:
618b2eb7c4SChristian Sigg   ///     gpu.barrier
628b2eb7c4SChristian Sigg   ///     %result = load %workgroup_buffer[%zero]
638b2eb7c4SChristian Sigg   ///     return %result
648b2eb7c4SChristian Sigg   ///
rewrite__anon5d3f75af0111::GpuAllReduceRewriter658b2eb7c4SChristian Sigg   void rewrite() {
668b2eb7c4SChristian Sigg     rewriter.setInsertionPoint(reduceOp);
678b2eb7c4SChristian Sigg 
688b2eb7c4SChristian Sigg     // Compute linear invocation index and workgroup size.
69aae51255SMogball     Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
70aae51255SMogball     Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
71aae51255SMogball     Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
72aae51255SMogball     Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
73aae51255SMogball     Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
74aae51255SMogball     Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
75a54f4eaeSMogball     Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
76a54f4eaeSMogball     Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
77a54f4eaeSMogball     Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
78a54f4eaeSMogball     Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
79a54f4eaeSMogball     Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
80a54f4eaeSMogball     Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
818b2eb7c4SChristian Sigg 
828b2eb7c4SChristian Sigg     // Compute lane id (invocation id withing the subgroup).
83a54f4eaeSMogball     Value subgroupMask =
84a54f4eaeSMogball         create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
85a54f4eaeSMogball     Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
86a54f4eaeSMogball     Value isFirstLane =
87a54f4eaeSMogball         create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
88a54f4eaeSMogball                               create<arith::ConstantIntOp>(0, int32Type));
898b2eb7c4SChristian Sigg 
908b2eb7c4SChristian Sigg     Value numThreadsWithSmallerSubgroupId =
91a54f4eaeSMogball         create<arith::SubIOp>(invocationIdx, laneId);
928b2eb7c4SChristian Sigg     // The number of active invocations starting from the current subgroup.
938b2eb7c4SChristian Sigg     // The consumers do not require the value to be clamped to the size of the
948b2eb7c4SChristian Sigg     // subgroup.
958b2eb7c4SChristian Sigg     Value activeWidth =
96a54f4eaeSMogball         create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
978b2eb7c4SChristian Sigg 
988b2eb7c4SChristian Sigg     // Create factory for op which accumulates to values.
998b2eb7c4SChristian Sigg     AccumulatorFactory accumFactory = getFactory();
1008b2eb7c4SChristian Sigg     assert(accumFactory && "failed to create accumulator factory");
1018b2eb7c4SChristian Sigg 
1028b2eb7c4SChristian Sigg     // Reduce elements within each subgroup to produce the intermediate results.
1038b2eb7c4SChristian Sigg     Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
1048b2eb7c4SChristian Sigg                                                 reduceOp.value(), accumFactory);
1058b2eb7c4SChristian Sigg 
1068b2eb7c4SChristian Sigg     // Add workgroup buffer to parent function for intermediate result.
1078b2eb7c4SChristian Sigg     Value buffer = createWorkgroupBuffer();
1088b2eb7c4SChristian Sigg 
1098b2eb7c4SChristian Sigg     // Write the intermediate results to workgroup memory, using the first lane
1108b2eb7c4SChristian Sigg     // of each subgroup.
1118b2eb7c4SChristian Sigg     createPredicatedBlock(isFirstLane, [&] {
1128b2eb7c4SChristian Sigg       Value subgroupId = getDivideBySubgroupSize(invocationIdx);
113a54f4eaeSMogball       Value index = create<arith::IndexCastOp>(indexType, subgroupId);
114e2310704SJulian Gross       create<memref::StoreOp>(subgroupReduce, buffer, index);
1158b2eb7c4SChristian Sigg     });
1168b2eb7c4SChristian Sigg     create<gpu::BarrierOp>();
1178b2eb7c4SChristian Sigg 
1188b2eb7c4SChristian Sigg     // Compute number of active subgroups.
1198b2eb7c4SChristian Sigg     Value biasedBlockSize =
120a54f4eaeSMogball         create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
1218b2eb7c4SChristian Sigg     Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
122a54f4eaeSMogball     Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
123a54f4eaeSMogball                                                   invocationIdx, numSubgroups);
1248b2eb7c4SChristian Sigg 
1258b2eb7c4SChristian Sigg     // Use the first numSubgroups invocations to reduce the intermediate results
1268b2eb7c4SChristian Sigg     // from workgroup memory. The final result is written to workgroup memory
1278b2eb7c4SChristian Sigg     // again.
128a54f4eaeSMogball     Value zero = create<arith::ConstantIndexOp>(0);
1298b2eb7c4SChristian Sigg     createPredicatedBlock(isValidSubgroup, [&] {
130a54f4eaeSMogball       Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
131e2310704SJulian Gross       Value value = create<memref::LoadOp>(valueType, buffer, index);
1328b2eb7c4SChristian Sigg       Value result =
1338b2eb7c4SChristian Sigg           createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
134e2310704SJulian Gross       create<memref::StoreOp>(result, buffer, zero);
1358b2eb7c4SChristian Sigg     });
1368b2eb7c4SChristian Sigg 
1378b2eb7c4SChristian Sigg     // Synchronize workgroup and load result from workgroup memory.
1388b2eb7c4SChristian Sigg     create<gpu::BarrierOp>();
139e2310704SJulian Gross     Value result = create<memref::LoadOp>(valueType, buffer, zero);
1408b2eb7c4SChristian Sigg 
1418b2eb7c4SChristian Sigg     rewriter.replaceOp(reduceOp, result);
1428b2eb7c4SChristian Sigg   }
1438b2eb7c4SChristian Sigg 
1448b2eb7c4SChristian Sigg private:
1458b2eb7c4SChristian Sigg   // Shortcut to create an op from rewriter using loc as the first argument.
146e2310704SJulian Gross   template <typename T, typename... Args>
create__anon5d3f75af0111::GpuAllReduceRewriter147e2310704SJulian Gross   T create(Args... args) {
1488b2eb7c4SChristian Sigg     return rewriter.create<T>(loc, std::forward<Args>(args)...);
1498b2eb7c4SChristian Sigg   }
1508b2eb7c4SChristian Sigg 
1518b2eb7c4SChristian Sigg   // Creates dimension op of type T, with the result casted to int32.
152e2310704SJulian Gross   template <typename T>
getDimOp__anon5d3f75af0111::GpuAllReduceRewriter153aae51255SMogball   Value getDimOp(gpu::Dimension dimension) {
154aae51255SMogball     Value dim = create<T>(indexType, dimension);
155a54f4eaeSMogball     return create<arith::IndexCastOp>(int32Type, dim);
1568b2eb7c4SChristian Sigg   }
1578b2eb7c4SChristian Sigg 
1588b2eb7c4SChristian Sigg   /// Adds type to funcOp's workgroup attributions.
createWorkgroupBuffer__anon5d3f75af0111::GpuAllReduceRewriter1598b2eb7c4SChristian Sigg   Value createWorkgroupBuffer() {
160e084679fSRiver Riddle     // TODO: Pick a proper location for the attribution.
161ad398164SWen-Heng (Jack) Chung     int workgroupMemoryAddressSpace =
162ad398164SWen-Heng (Jack) Chung         gpu::GPUDialect::getWorkgroupAddressSpace();
163e41ebbecSVladislav Vinogradov     auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
1648b2eb7c4SChristian Sigg                                       workgroupMemoryAddressSpace);
165e084679fSRiver Riddle     return funcOp.addWorkgroupAttribution(bufferType, rewriter.getUnknownLoc());
1668b2eb7c4SChristian Sigg   }
1678b2eb7c4SChristian Sigg 
1688b2eb7c4SChristian Sigg   /// Returns an accumulator factory using either the op attribute or the body
1698b2eb7c4SChristian Sigg   /// region.
getFactory__anon5d3f75af0111::GpuAllReduceRewriter1708b2eb7c4SChristian Sigg   AccumulatorFactory getFactory() {
1718b2eb7c4SChristian Sigg     auto &body = reduceOp.body();
1728b2eb7c4SChristian Sigg     if (!body.empty())
1738b2eb7c4SChristian Sigg       return getFactory(body);
1748b2eb7c4SChristian Sigg     auto opAttr = reduceOp.op();
1758b2eb7c4SChristian Sigg     if (opAttr)
1768b2eb7c4SChristian Sigg       return getFactory(*opAttr);
1778b2eb7c4SChristian Sigg     return AccumulatorFactory();
1788b2eb7c4SChristian Sigg   }
1798b2eb7c4SChristian Sigg 
1808b2eb7c4SChristian Sigg   /// Returns an accumulator factory that clones the body. The body's entry
1818b2eb7c4SChristian Sigg   /// block is expected to have 2 arguments. The gpu.yield return the
1828b2eb7c4SChristian Sigg   /// accumulated value of the same type.
getFactory__anon5d3f75af0111::GpuAllReduceRewriter1838b2eb7c4SChristian Sigg   AccumulatorFactory getFactory(Region &body) {
1848b2eb7c4SChristian Sigg     return AccumulatorFactory([&](Value lhs, Value rhs) {
1858b2eb7c4SChristian Sigg       Block *block = rewriter.getInsertionBlock();
1868b2eb7c4SChristian Sigg       Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
1878b2eb7c4SChristian Sigg 
1888b2eb7c4SChristian Sigg       // Insert accumulator body between split block.
1898b2eb7c4SChristian Sigg       BlockAndValueMapping mapping;
190e2b71610SRahul Joshi       mapping.map(body.getArgument(0), lhs);
191e2b71610SRahul Joshi       mapping.map(body.getArgument(1), rhs);
1928b2eb7c4SChristian Sigg       rewriter.cloneRegionBefore(body, *split->getParent(),
1938b2eb7c4SChristian Sigg                                  split->getIterator(), mapping);
1948b2eb7c4SChristian Sigg 
1958b2eb7c4SChristian Sigg       // Add branch before inserted body, into body.
1968b2eb7c4SChristian Sigg       block = block->getNextNode();
197ace01605SRiver Riddle       create<cf::BranchOp>(block, ValueRange());
1988b2eb7c4SChristian Sigg 
1998b2eb7c4SChristian Sigg       // Replace all gpu.yield ops with branch out of body.
2008b2eb7c4SChristian Sigg       for (; block != split; block = block->getNextNode()) {
2018b2eb7c4SChristian Sigg         Operation *terminator = block->getTerminator();
2028b2eb7c4SChristian Sigg         if (!isa<gpu::YieldOp>(terminator))
2038b2eb7c4SChristian Sigg           continue;
2048b2eb7c4SChristian Sigg         rewriter.setInsertionPointToEnd(block);
205ace01605SRiver Riddle         rewriter.replaceOpWithNewOp<cf::BranchOp>(
2068b2eb7c4SChristian Sigg             terminator, split, ValueRange(terminator->getOperand(0)));
2078b2eb7c4SChristian Sigg       }
2088b2eb7c4SChristian Sigg 
2098b2eb7c4SChristian Sigg       // Return accumulator result.
2108b2eb7c4SChristian Sigg       rewriter.setInsertionPointToStart(split);
211e084679fSRiver Riddle       return split->addArgument(lhs.getType(), lhs.getLoc());
2128b2eb7c4SChristian Sigg     });
2138b2eb7c4SChristian Sigg   }
2148b2eb7c4SChristian Sigg 
2158b2eb7c4SChristian Sigg   /// Returns an accumulator factory that creates an op specified by opName.
getFactory__anon5d3f75af0111::GpuAllReduceRewriter216aae51255SMogball   AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
2178b2eb7c4SChristian Sigg     bool isFloatingPoint = valueType.isa<FloatType>();
218aae51255SMogball     switch (opName) {
219aae51255SMogball     case gpu::AllReduceOperation::ADD:
220a54f4eaeSMogball       return isFloatingPoint ? getFactory<arith::AddFOp>()
221a54f4eaeSMogball                              : getFactory<arith::AddIOp>();
222aae51255SMogball     case gpu::AllReduceOperation::MUL:
223a54f4eaeSMogball       return isFloatingPoint ? getFactory<arith::MulFOp>()
224a54f4eaeSMogball                              : getFactory<arith::MulIOp>();
225aae51255SMogball     case gpu::AllReduceOperation::AND:
226a54f4eaeSMogball       return getFactory<arith::AndIOp>();
227aae51255SMogball     case gpu::AllReduceOperation::OR:
228a54f4eaeSMogball       return getFactory<arith::OrIOp>();
229aae51255SMogball     case gpu::AllReduceOperation::XOR:
230a54f4eaeSMogball       return getFactory<arith::XOrIOp>();
231aae51255SMogball     case gpu::AllReduceOperation::MAX:
232c7380995SValentin Clement       return isFloatingPoint
233a54f4eaeSMogball                  ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
234a54f4eaeSMogball                                  arith::CmpFPredicate::UGT>()
235a54f4eaeSMogball                  : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
236a54f4eaeSMogball                                  arith::CmpIPredicate::ugt>();
237aae51255SMogball     case gpu::AllReduceOperation::MIN:
238c7380995SValentin Clement       return isFloatingPoint
239a54f4eaeSMogball                  ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
240a54f4eaeSMogball                                  arith::CmpFPredicate::ULT>()
241a54f4eaeSMogball                  : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
242a54f4eaeSMogball                                  arith::CmpIPredicate::ult>();
243c7380995SValentin Clement     }
244dc3b9365SAlexandre Ganea     llvm_unreachable("unknown GPU AllReduceOperation");
2458b2eb7c4SChristian Sigg   }
2468b2eb7c4SChristian Sigg 
2478b2eb7c4SChristian Sigg   /// Returns an accumulator factory that creates an op of type T.
248e2310704SJulian Gross   template <typename T>
getFactory__anon5d3f75af0111::GpuAllReduceRewriter249e2310704SJulian Gross   AccumulatorFactory getFactory() {
2508b2eb7c4SChristian Sigg     return [&](Value lhs, Value rhs) {
2518b2eb7c4SChristian Sigg       return create<T>(lhs.getType(), lhs, rhs);
2528b2eb7c4SChristian Sigg     };
2538b2eb7c4SChristian Sigg   }
2548b2eb7c4SChristian Sigg 
2555aacce3dSKazuaki Ishizaki   /// Returns an accumulator for comparison such as min, max. T is the type
256c7380995SValentin Clement   /// of the compare op.
257c7380995SValentin Clement   template <typename T, typename PredicateEnum, PredicateEnum predicate>
getCmpFactory__anon5d3f75af0111::GpuAllReduceRewriter258c7380995SValentin Clement   AccumulatorFactory getCmpFactory() const {
259c7380995SValentin Clement     return [&](Value lhs, Value rhs) {
260c7380995SValentin Clement       Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
261dec8af70SRiver Riddle       return rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs);
262c7380995SValentin Clement     };
263c7380995SValentin Clement   }
264c7380995SValentin Clement 
2658b2eb7c4SChristian Sigg   /// Creates an if-block skeleton and calls the two factories to generate the
2668b2eb7c4SChristian Sigg   /// ops in the `then` and `else` block..
2678b2eb7c4SChristian Sigg   ///
2688b2eb7c4SChristian Sigg   ///     llvm.cond_br %condition, ^then, ^continue
2698b2eb7c4SChristian Sigg   ///   ^then:
2708b2eb7c4SChristian Sigg   ///     %then_operands = `thenOpsFactory()`
2718b2eb7c4SChristian Sigg   ///     llvm.br ^continue(%then_operands)
2728b2eb7c4SChristian Sigg   ///   ^else:
2738b2eb7c4SChristian Sigg   ///     %else_operands = `elseOpsFactory()`
2748b2eb7c4SChristian Sigg   ///     llvm.br ^continue(%else_operands)
2758b2eb7c4SChristian Sigg   ///   ^continue(%block_operands):
2768b2eb7c4SChristian Sigg   ///
2778b2eb7c4SChristian Sigg   template <typename ThenOpsFactory, typename ElseOpsFactory>
createIf__anon5d3f75af0111::GpuAllReduceRewriter2788b2eb7c4SChristian Sigg   void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
2798b2eb7c4SChristian Sigg                 ElseOpsFactory &&elseOpsFactory) {
2808b2eb7c4SChristian Sigg     Block *currentBlock = rewriter.getInsertionBlock();
2818b2eb7c4SChristian Sigg     auto currentPoint = rewriter.getInsertionPoint();
2828b2eb7c4SChristian Sigg 
2838b2eb7c4SChristian Sigg     Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
2848b2eb7c4SChristian Sigg     Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
2858b2eb7c4SChristian Sigg     Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
2868b2eb7c4SChristian Sigg 
2878b2eb7c4SChristian Sigg     rewriter.setInsertionPointToEnd(currentBlock);
288ace01605SRiver Riddle     create<cf::CondBranchOp>(condition, thenBlock,
2898b2eb7c4SChristian Sigg                              /*trueOperands=*/ArrayRef<Value>(), elseBlock,
2908b2eb7c4SChristian Sigg                              /*falseOperands=*/ArrayRef<Value>());
2918b2eb7c4SChristian Sigg 
2928b2eb7c4SChristian Sigg     rewriter.setInsertionPointToStart(thenBlock);
2938b2eb7c4SChristian Sigg     auto thenOperands = thenOpsFactory();
294ace01605SRiver Riddle     create<cf::BranchOp>(continueBlock, thenOperands);
2958b2eb7c4SChristian Sigg 
2968b2eb7c4SChristian Sigg     rewriter.setInsertionPointToStart(elseBlock);
2978b2eb7c4SChristian Sigg     auto elseOperands = elseOpsFactory();
298ace01605SRiver Riddle     create<cf::BranchOp>(continueBlock, elseOperands);
2998b2eb7c4SChristian Sigg 
3008b2eb7c4SChristian Sigg     assert(thenOperands.size() == elseOperands.size());
3018b2eb7c4SChristian Sigg     rewriter.setInsertionPointToStart(continueBlock);
3028b2eb7c4SChristian Sigg     for (auto operand : thenOperands)
303e084679fSRiver Riddle       continueBlock->addArgument(operand.getType(), operand.getLoc());
3048b2eb7c4SChristian Sigg   }
3058b2eb7c4SChristian Sigg 
3068b2eb7c4SChristian Sigg   /// Shortcut for createIf with empty else block and no block operands.
3078b2eb7c4SChristian Sigg   template <typename Factory>
createPredicatedBlock__anon5d3f75af0111::GpuAllReduceRewriter3088b2eb7c4SChristian Sigg   void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
3098b2eb7c4SChristian Sigg     static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
3108b2eb7c4SChristian Sigg                   "predicatedOpsFactory should not return any value");
3118b2eb7c4SChristian Sigg     createIf(
3128b2eb7c4SChristian Sigg         condition,
3138b2eb7c4SChristian Sigg         [&] {
3148b2eb7c4SChristian Sigg           predicatedOpsFactory();
3158b2eb7c4SChristian Sigg           return ArrayRef<Value>();
3168b2eb7c4SChristian Sigg         },
3178b2eb7c4SChristian Sigg         [&] { return ArrayRef<Value>(); });
3188b2eb7c4SChristian Sigg   }
3198b2eb7c4SChristian Sigg 
3208b2eb7c4SChristian Sigg   /// Creates a reduction across the first activeWidth lanes of a subgroup, or
3218b2eb7c4SChristian Sigg   /// the entire subgroup if activeWidth is larger than the subgroup width.
3228b2eb7c4SChristian Sigg   /// The first lane returns the result, all others return values are undefined.
createSubgroupReduce__anon5d3f75af0111::GpuAllReduceRewriter3238b2eb7c4SChristian Sigg   Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
3248b2eb7c4SChristian Sigg                              AccumulatorFactory &accumFactory) {
325a54f4eaeSMogball     Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
326a54f4eaeSMogball     Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
327a54f4eaeSMogball                                                     activeWidth, subgroupSize);
3285fc5c7dbSBenjamin Kramer     std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
3298b2eb7c4SChristian Sigg 
3308b2eb7c4SChristian Sigg     createIf(
3318b2eb7c4SChristian Sigg         isPartialSubgroup,
3328b2eb7c4SChristian Sigg         // Generate reduction over a (potentially) partial subgroup.
3338b2eb7c4SChristian Sigg         [&] {
3348b2eb7c4SChristian Sigg           Value value = operand;
3358b2eb7c4SChristian Sigg           // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
3368b2eb7c4SChristian Sigg           // lane is within the active range. The accumulated value is available
3378b2eb7c4SChristian Sigg           // in the first lane.
3388b2eb7c4SChristian Sigg           for (int i = 1; i < kSubgroupSize; i <<= 1) {
339a54f4eaeSMogball             Value offset = create<arith::ConstantIntOp>(i, int32Type);
340aae51255SMogball             auto shuffleOp = create<gpu::ShuffleOp>(
341aae51255SMogball                 shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
3428b2eb7c4SChristian Sigg             // Skip the accumulation if the shuffle op read from a lane outside
3438b2eb7c4SChristian Sigg             // of the active range.
3448b2eb7c4SChristian Sigg             createIf(
3458b2eb7c4SChristian Sigg                 shuffleOp.getResult(1),
3468b2eb7c4SChristian Sigg                 [&] {
3478b2eb7c4SChristian Sigg                   return SmallVector<Value, 1>{
3488b2eb7c4SChristian Sigg                       accumFactory(value, shuffleOp.getResult(0))};
3498b2eb7c4SChristian Sigg                 },
3508b2eb7c4SChristian Sigg                 [&] { return llvm::makeArrayRef(value); });
3518b2eb7c4SChristian Sigg             value = rewriter.getInsertionBlock()->getArgument(0);
3528b2eb7c4SChristian Sigg           }
3538b2eb7c4SChristian Sigg           return SmallVector<Value, 1>{value};
3548b2eb7c4SChristian Sigg         },
3558b2eb7c4SChristian Sigg         // Generate a reduction over the entire subgroup. This is a
3568b2eb7c4SChristian Sigg         // specialization of the above reduction with unconditional
3578b2eb7c4SChristian Sigg         // accumulation.
3588b2eb7c4SChristian Sigg         [&] {
3598b2eb7c4SChristian Sigg           Value value = operand;
3608b2eb7c4SChristian Sigg           for (int i = 1; i < kSubgroupSize; i <<= 1) {
361a54f4eaeSMogball             Value offset = create<arith::ConstantIntOp>(i, int32Type);
362aae51255SMogball             auto shuffleOp =
363aae51255SMogball                 create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
364aae51255SMogball                                        gpu::ShuffleMode::XOR);
3658b2eb7c4SChristian Sigg             value = accumFactory(value, shuffleOp.getResult(0));
3668b2eb7c4SChristian Sigg           }
3678b2eb7c4SChristian Sigg           return SmallVector<Value, 1>{value};
3688b2eb7c4SChristian Sigg         });
3698b2eb7c4SChristian Sigg     return rewriter.getInsertionBlock()->getArgument(0);
3708b2eb7c4SChristian Sigg   }
3718b2eb7c4SChristian Sigg 
3728b2eb7c4SChristian Sigg   /// Returns value divided by the subgroup size (i.e. 32).
getDivideBySubgroupSize__anon5d3f75af0111::GpuAllReduceRewriter3738b2eb7c4SChristian Sigg   Value getDivideBySubgroupSize(Value value) {
374a54f4eaeSMogball     Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
375a54f4eaeSMogball     return create<arith::DivSIOp>(int32Type, value, subgroupSize);
3768b2eb7c4SChristian Sigg   }
3778b2eb7c4SChristian Sigg 
3788b2eb7c4SChristian Sigg   gpu::GPUFuncOp funcOp;
3798b2eb7c4SChristian Sigg   gpu::AllReduceOp reduceOp;
3808b2eb7c4SChristian Sigg   PatternRewriter &rewriter;
3818b2eb7c4SChristian Sigg 
3828b2eb7c4SChristian Sigg   Location loc;
3838b2eb7c4SChristian Sigg   Type valueType;
3848b2eb7c4SChristian Sigg   Type indexType;
385a54f4eaeSMogball   IntegerType int32Type;
3868b2eb7c4SChristian Sigg 
3878b2eb7c4SChristian Sigg   static constexpr int kSubgroupSize = 32;
3888b2eb7c4SChristian Sigg };
3898b2eb7c4SChristian Sigg 
3908b2eb7c4SChristian Sigg struct GpuAllReduceConversion : public RewritePattern {
GpuAllReduceConversion__anon5d3f75af0111::GpuAllReduceConversion3918b2eb7c4SChristian Sigg   explicit GpuAllReduceConversion(MLIRContext *context)
3928b2eb7c4SChristian Sigg       : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
3938b2eb7c4SChristian Sigg 
matchAndRewrite__anon5d3f75af0111::GpuAllReduceConversion3943145427dSRiver Riddle   LogicalResult matchAndRewrite(Operation *op,
3958b2eb7c4SChristian Sigg                                 PatternRewriter &rewriter) const override {
3968b2eb7c4SChristian Sigg     auto funcOp = cast<gpu::GPUFuncOp>(op);
3978b2eb7c4SChristian Sigg     auto callback = [&](gpu::AllReduceOp reduceOp) {
3988b2eb7c4SChristian Sigg       GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
3998b2eb7c4SChristian Sigg       // Performing a rewrite invalidates the walk iterator. Report interrupt
4008b2eb7c4SChristian Sigg       // so that we can start a new walk until all all_reduce ops are replaced.
4018b2eb7c4SChristian Sigg       return WalkResult::interrupt();
4028b2eb7c4SChristian Sigg     };
4038b2eb7c4SChristian Sigg     while (funcOp.walk(callback).wasInterrupted()) {
4048b2eb7c4SChristian Sigg     }
4053145427dSRiver Riddle     return success();
4068b2eb7c4SChristian Sigg   }
4078b2eb7c4SChristian Sigg };
4088b2eb7c4SChristian Sigg } // namespace
4098b2eb7c4SChristian Sigg 
populateGpuAllReducePatterns(RewritePatternSet & patterns)410dc4e913bSChris Lattner void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
411dc4e913bSChris Lattner   patterns.add<GpuAllReduceConversion>(patterns.getContext());
4128b2eb7c4SChristian Sigg }
413