//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);
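
    // Worked example of the linearization above (values are illustrative,
    // not taken from any particular target): for block dimensions
    // (dimX, dimY, dimZ) = (8, 4, 2) and an invocation at (tidX, tidY, tidZ)
    // = (3, 1, 1), the linear index is ((1 * 4 + 1) * 8) + 3 = 43 and the
    // workgroup size is (8 * 4) * 2 = 64.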

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
                                       create<ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create factory for op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<IndexCastOp>(indexType, subgroupId);
      create<StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<IndexCastOp>(indexType, invocationIdx);
      Value value = create<LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args> T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T> Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory buffer for the intermediate results to funcOp's
  /// workgroup attributions.
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace = 3;
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }
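
  // For reference, the two forms of gpu.all_reduce that the factories below
  // handle might look as follows in the input IR (an illustrative sketch, not
  // taken from this file's tests). The attribute form names a known
  // accumulator:
  //
  //   %sum = "gpu.all_reduce"(%operand) ({}) {op = "add"} : (f32) -> (f32)
  //
  // The region form carries a custom accumulator body whose entry block takes
  // the two values to combine and yields their reduction:
  //
  //   %sum = "gpu.all_reduce"(%operand) ({
  //   ^bb(%lhs : f32, %rhs : f32):
  //     %result = addf %lhs, %rhs : f32
  //     "gpu.yield"(%result) : (f32) -> ()
  //   }) : (f32) -> (f32)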

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op is expected to
  /// return the accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the current and the split block.
      BlockAndValueMapping mapping;
      mapping.map(body.front().getArgument(0), lhs);
      mapping.map(body.front().getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    if (opName == "and")
      return getFactory<AndOp>();
    if (opName == "or")
      return getFactory<OrOp>();
    if (opName == "xor")
      return getFactory<XOrOp>();
    if (opName == "max")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
    if (opName == "min")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T> AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator for comparison such as min, max. T is the type
  /// of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }
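
  // For example (illustrative IR, not from this file's tests), the factory
  // for "max" over f32 values emits a compare-and-select pair for each
  // accumulation:
  //
  //   %cmp = cmpf "ugt", %lhs, %rhs : f32
  //   %max = select %cmp, %lhs, %rhs : f32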

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` block.
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
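            // Illustrative trace of the xor butterfly (with the
            // kSubgroupSize of 32 used here): the loop issues shuffles with
            // offsets 1, 2, 4, 8, 16; after the step with offset i, each lane
            // has combined the active lanes of its 2*i-wide group, so lane 0
            // ends up holding the reduction over all active lanes.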
            // Skip the accumulation if the shuffle op read from a lane
            // outside of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    return create<SignedDivIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  Type int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuRewritePatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
  patterns.insert<GpuAllReduceConversion>(context);
}
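
// For illustration, a pass exercising this rewrite could be wired up roughly
// as follows (a hypothetical sketch against this revision's greedy pattern
// driver; the pass name is made up and the exact driver API may differ):
//
//   struct TestAllReduceLoweringPass
//       : public OperationPass<TestAllReduceLoweringPass, ModuleOp> {
//     void runOnOperation() override {
//       OwningRewritePatternList patterns;
//       populateGpuRewritePatterns(&getContext(), patterns);
//       applyPatternsGreedily(getOperation(), patterns);
//     }
//   };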