//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
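    // Since kSubgroupSize is a power of two, the bitwise AND below computes
    // invocationIdx % kSubgroupSize, i.e. the lane id.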
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create factory for op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory buffer to funcOp's workgroup attributions and
  /// returns it.
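  ///
  /// A buffer of kSubgroupSize elements is sufficient because only the first
  /// lane of each subgroup writes a partial result; with 32-wide subgroups
  /// this assumes workgroups of at most 32 * 32 = 1024 invocations, which is
  /// the common upper limit on current GPUs.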
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield is expected to
  /// return the accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the split blocks.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<arith::AddFOp>()
                             : getFactory<arith::AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<arith::MulFOp>()
                             : getFactory<arith::MulIOp>();
    if (opName == "and")
      return getFactory<arith::AndIOp>();
    if (opName == "or")
      return getFactory<arith::OrIOp>();
    if (opName == "xor")
      return getFactory<arith::XOrIOp>();
    if (opName == "max") {
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::UGT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ugt>();
    }
    if (opName == "min") {
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::ULT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ult>();
    }
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator for comparison such as min, max. T is the type of
  /// the compare op.
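  ///
  /// For example, getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
  /// arith::CmpFPredicate::UGT>() accumulates two values by emitting
  /// %cmp = arith.cmpf ugt, %lhs, %rhs followed by select %cmp, %lhs, %rhs,
  /// i.e. a floating-point maximum.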
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; all other lanes' return values are
  /// undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
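          // With kSubgroupSize == 32 the xor offsets are 1, 2, 4, 8, 16; after
          // log2(kSubgroupSize) butterfly steps lane 0 has accumulated the
          // values of all active lanes.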
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceConversion>(patterns.getContext());
}