//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
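    // Masking with kSubgroupSize - 1 is equivalent to invocationIdx %
    // kSubgroupSize because the subgroup size is a power of two.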
    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
                                       create<ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<IndexCastOp>(indexType, subgroupId);
      create<StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<IndexCastOp>(indexType, invocationIdx);
      Value value = create<LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args> T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T> Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory buffer of kSubgroupSize elements to funcOp's
  /// workgroup attributions and returns it.
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace = 3;
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
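  ///
  /// For illustration only (generic MLIR assembly, approximate for this
  /// version of the GPU dialect), the two supported forms look like:
  ///
  ///   %sum = "gpu.all_reduce"(%operand) ({}) { op = "add" } : (f32) -> (f32)
  ///
  ///   %sum = "gpu.all_reduce"(%operand) ({
  ///   ^bb(%lhs : f32, %rhs : f32):
  ///     %res = addf %lhs, %rhs : f32
  ///     "gpu.yield"(%res) : (f32) -> ()
  ///   }) : (f32) -> (f32)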
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield ops are expected to
  /// return the accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the split blocks.
      BlockAndValueMapping mapping;
      mapping.map(body.front().getArgument(0), lhs);
      mapping.map(body.front().getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch from before the inserted body into the body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with a branch out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return the accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T> AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
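  /// Both factories must return the same number of values; these become the
  /// block arguments of the continue block: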
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with an empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is at least as large as the subgroup
  /// size. The result is returned in the first lane; the values returned in
  /// all other lanes are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
    SmallVector<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
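            // (The second result of gpu.shuffle is an i1 that is true iff the
            // source lane was within the active width.)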
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    return create<SignedDivIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  Type int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(Operation *op,
                                     PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return matchSuccess();
  }
};
} // namespace

void mlir::populateGpuRewritePatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
  patterns.insert<GpuAllReduceConversion>(context);
}
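
// A minimal sketch of how a client pass might drive these patterns (not part
// of this file; the pass name and boilerplate below are illustrative, and the
// `applyPatternsGreedily` call is assumed from the pattern-rewrite
// infrastructure available in this version of MLIR):
//
//   struct LowerAllReducePass : public ModulePass<LowerAllReducePass> {
//     void runOnModule() override {
//       OwningRewritePatternList patterns;
//       populateGpuRewritePatterns(&getContext(), patterns);
//       applyPatternsGreedily(getModule(), patterns);
//     }
//   };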