//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp, gpu::AllReduceOp reduceOp,
                       PatternRewriter &rewriter)
      : funcOp(funcOp), reduceOp(reduceOp), rewriter(rewriter),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = arith.cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
    Value dimX = getDimOp<gpu::BlockDimOp>(gpu::Dimension::x);
    Value dimY = getDimOp<gpu::BlockDimOp>(gpu::Dimension::y);
    Value dimZ = getDimOp<gpu::BlockDimOp>(gpu::Dimension::z);
    Value tidX = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::x);
    Value tidY = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::y);
    Value tidZ = getDimOp<gpu::ThreadIdOp>(gpu::Dimension::z);
    Value tmp1 = create<arith::MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<arith::AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<arith::MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<arith::MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<arith::AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<arith::MulIOp>(int32Type, tmp4, dimZ);
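    // The lines above compute
    //   invocation_idx = (tid.z * dim.y + tid.y) * dim.x + tid.x
    //   workgroup_size = dim.x * dim.y * dim.z
    // e.g. for a 4x2x2 workgroup, the invocation with tid = (3, 1, 1) gets
    // invocation_idx = (1 * 2 + 1) * 4 + 3 = 15 and workgroup_size = 16.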

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask =
        create<arith::ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<arith::AndIOp>(invocationIdx, subgroupMask);
    Value isFirstLane =
        create<arith::CmpIOp>(arith::CmpIPredicate::eq, laneId,
                              create<arith::ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<arith::SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<arith::SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create the factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<arith::IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
    Value biasedBlockSize =
        create<arith::AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                  invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<arith::ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<arith::IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result cast to int32.
  template <typename T>
  Value getDimOp(gpu::Dimension dimension) {
    Value dim = create<T>(indexType, dimension);
    return create<arith::IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup buffer attribution to funcOp and returns it.
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType = MemRefType::get({kSubgroupSize}, valueType, AffineMap{},
                                      workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op returns the
  /// accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the current block and the split
      // block.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch from the block before the inserted body into the body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with a branch out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(gpu::AllReduceOperation opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    switch (opName) {
    case gpu::AllReduceOperation::ADD:
      return isFloatingPoint ? getFactory<arith::AddFOp>()
                             : getFactory<arith::AddIOp>();
    case gpu::AllReduceOperation::MUL:
      return isFloatingPoint ? getFactory<arith::MulFOp>()
                             : getFactory<arith::MulIOp>();
    case gpu::AllReduceOperation::AND:
      return getFactory<arith::AndIOp>();
    case gpu::AllReduceOperation::OR:
      return getFactory<arith::OrIOp>();
    case gpu::AllReduceOperation::XOR:
      return getFactory<arith::XOrIOp>();
    case gpu::AllReduceOperation::MAX:
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::UGT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ugt>();
    case gpu::AllReduceOperation::MIN:
      return isFloatingPoint
                 ? getCmpFactory<arith::CmpFOp, arith::CmpFPredicate,
                                 arith::CmpFPredicate::ULT>()
                 : getCmpFactory<arith::CmpIOp, arith::CmpIPredicate,
                                 arith::CmpIPredicate::ult>();
    }
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator for comparisons such as min and max. T is the type
  /// of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///   cond_br %condition, ^then, ^else
  /// ^then:
  ///   %then_operands = `thenOpsFactory()`
  ///   br ^continue(%then_operands)
  /// ^else:
  ///   %else_operands = `elseOpsFactory()`
  ///   br ^continue(%else_operands)
  /// ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup = create<arith::CmpIOp>(arith::CmpIPredicate::slt,
                                                    activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
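          // For example, with four active lanes holding v0..v3, the offset-1
          // step leaves lane 0 with v0+v1 and lane 2 with v2+v3, and the
          // offset-2 step leaves lane 0 with v0+v1+v2+v3; larger offsets read
          // from inactive lanes and are skipped.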
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(
                shuffleType, value, offset, activeWidth, gpu::ShuffleMode::XOR);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<arith::ConstantIntOp>(i, int32Type);
            auto shuffleOp =
                create<gpu::ShuffleOp>(shuffleType, value, offset, subgroupSize,
                                       gpu::ShuffleMode::XOR);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<arith::ConstantIntOp>(kSubgroupSize, int32Type);
    return create<arith::DivSIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  IntegerType int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  patterns.add<GpuAllReduceConversion>(patterns.getContext());
}
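
// A minimal sketch of how these patterns might be driven, e.g. from a test
// pass using the greedy driver in mlir/Transforms/GreedyPatternRewriteDriver.h;
// the pass name below is illustrative and the choice of driver is an
// assumption, not something this file prescribes:
//
//   struct TestGpuAllReduceLoweringPass
//       : public PassWrapper<TestGpuAllReduceLoweringPass,
//                            OperationPass<ModuleOp>> {
//     void runOnOperation() override {
//       RewritePatternSet patterns(&getContext());
//       populateGpuAllReducePatterns(patterns);
//       (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
//     }
//   };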