//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
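    // Since kSubgroupSize is a power of two, masking with kSubgroupSize - 1 is
    // equivalent to invocationIdx % kSubgroupSize: with the default subgroup
    // size of 32, invocation 37 has lane id 37 & 31 = 5.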
    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
                                       create<ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create a factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add a workgroup buffer to the parent function for the intermediate
    // result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<IndexCastOp>(indexType, subgroupId);
      create<StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute the number of active subgroups.
    Value biasedBlockSize =
        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<IndexCastOp>(indexType, invocationIdx);
      Value value = create<LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<StoreOp>(result, buffer, zero);
    });

    // Synchronize the workgroup and load the result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args> T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T> Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup buffer of the reduced value's type to funcOp's workgroup
  /// attributions.
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
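  /// An all_reduce carrying, e.g., an `op = "add"` attribute is handled by the
  /// StringRef overload below, whereas one with a non-empty body region (a
  /// custom accumulator terminated by gpu.yield) is handled by the Region
  /// overload; the region, when present, is checked first.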
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op is expected to
  /// return the accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the split blocks.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch before the inserted body, into the body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with a branch out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return the accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    if (opName == "and")
      return getFactory<AndOp>();
    if (opName == "or")
      return getFactory<OrOp>();
    if (opName == "xor")
      return getFactory<XOrOp>();
    if (opName == "max")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
    if (opName == "min")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T> AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator for comparison ops such as min and max. T is the
  /// type of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` block.
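  /// Both factories must return the same number of values, which become the
  /// block arguments of ^continue.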
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with an empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate a reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
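            // The second shuffle result is a validity bit: in "xor" mode lane
            // L reads from lane L ^ i, and the bit is false when that source
            // lane lies at or beyond the active width.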
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns the value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    return create<SignedDivIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  Type int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report an
      // interrupt so that we can start a new walk until all all_reduce ops are
      // replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuRewritePatterns(MLIRContext *context,
                                      OwningRewritePatternList &patterns) {
  patterns.insert<GpuAllReduceConversion>(context);
}
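
// A minimal sketch of how a pass might drive this rewrite (hypothetical
// driver, not part of this file; the greedy-driver entry point and the op the
// pass runs on depend on the MLIR revision):
//
//   OwningRewritePatternList patterns;
//   populateGpuRewritePatterns(parentOp->getContext(), patterns);
//   applyPatternsAndFoldGreedily(parentOp, patterns);
//
// Here `parentOp` stands for whatever op the pass operates on; anything that
// transitively contains gpu.func ops works, since the pattern above matches
// gpu.func and walks the all_reduce ops inside it.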