//===- VectorDistribute.cpp - patterns to do vector distribution ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"

using namespace mlir;
using namespace mlir::vector;

/// Lower a WarpExecuteOnLane0Op to an scf.if guarded on lane 0, communicating
/// region arguments and results through memory allocated via `options`.
static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
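  // Each yielded value goes through a scratch buffer: lane 0 stores the full
  // value inside the scf.if, and every lane loads its distributed piece (or
  // the broadcast scalar) after the scf.if. A sketch of the result, with
  // invented names, for a vector<32xf32> yield distributed over 32 lanes:
  //   scf.if %isLane0 {
  //     vector.store %yielded, %buffer[%c0] : memref<32xf32>, vector<32xf32>
  //   }
  //   gpu.barrier  // or whatever options.warpSyncronizationFn emits
  //   %off = %laneid * 1
  //   %piece = vector.load %buffer[%off] : memref<32xf32>, vector<1xf32>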
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcast to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
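/// The original `warpOp` is replaced and its results are remapped to the
/// leading results of the new op. Illustrative sketch (SSA names invented):
/// appending a yielded value %v : vector<32xf32> with distributed return type
/// vector<1xf32> turns
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid) -> (f32) {
///   vector.yield %s : f32
/// }
/// ```
/// into
/// ```
/// %1:2 = vector.warp_execute_on_lane_0(%laneid) -> (f32, vector<1xf32>) {
///   vector.yield %s, %v : f32, vector<32xf32>
/// }
/// ```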
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  types.append(newReturnTypes.begin(), newReturnTypes.end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

// Clones `op` into a new operation that takes `operands` and returns
// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(Value yield, Value ret) {
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value are different from the
  // dimensions of the result to know the distributed dimensions. Then
  // associate each distributed dimension to an ID in order.
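  // E.g. yield type vector<32x16x64xf32> vs result type vector<1x16x2xf32>:
  // dims 0 and 2 differ (32 vs 1, 64 vs 2), so the map is
  // (d0, d1, d2) -> (d0, d2).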
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(fn) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the
  /// moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    AffineMap map = distributionMapFn(writeOp);
    SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(),
                                     writeOp.getVectorType().getShape().end());
    assert(map.getNumResults() == 1 &&
           "multi-dim distribution not implemented yet");
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writeOp.getVectorType().getElementType());

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {targetType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Move op outside of region: Insert clone at the insertion point and
    // delete the old op.
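    // The clone keeps the original indices; they are rewritten below so that
    // each lane writes its own slice. E.g. (illustrative) a vector<32xf32>
    // write distributed over warp size 32 becomes a vector<1xf32> write at
    // index %origIndex + 1 * %laneid.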
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    rewriter.eraseOp(writeOp);

    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale =
          getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext());
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only vector<1x> is supported at the moment.
    if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
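    // Sinking the write out of the region would otherwise reorder it with
    // those ops, e.g. (illustrative) a subsequent memref.store to the same
    // buffer would incorrectly execute before the distributed write.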
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out an elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(i + numResults);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out a transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this into WarpExecuteOnLane0Op canonicalization.
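/// For example (illustrative), if %0#1 below is unused:
/// ```
/// %0:2 = vector.warp_execute_on_lane_0(%laneid)
///     -> (vector<1xf32>, vector<2xf32>) {
///   ...
///   vector.yield %a, %b : vector<32xf32>, vector<64xf32>
/// }
/// ```
/// the op is rewritten to yield only %a and return just vector<1xf32>.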
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

// If an operand is directly yielded out of the region, we can forward it
// without it needing to go through the region.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
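      // E.g. (illustrative) a value defined above the warp op and yielded
      // unchanged is identical on every lane, so the corresponding warp
      // result can be replaced with it directly:
      //   %r = vector.warp_execute_on_lane_0(%laneid) -> (f32) {
      //     vector.yield %uniform : f32  // %uniform defined above the op
      //   }
      //   // ... uses of %r become uses of %uniform.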
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};

/// Sink out a broadcast op feeding into a warp op yield: the broadcast source
/// is yielded from the warp op instead and re-broadcast to the distributed
/// result type afterwards.
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()});
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResults().back());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);

    return success();
  }
};

/// Sink the scf.for region out of WarpExecuteOnLane0Op. This can be done only
/// if the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region; hoisting an
    // earlier op would change the order of execution.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp results coming from the original forOp. Operand
    // index + 3 skips the lower bound, upper bound, and step operands of the
    // new scf.for.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vectors whose size
/// equals the warp size (e.g. vector<32xf32> with a warp size of 32): every
/// lane combines its value with one other lane's, log2(32) = 5 times in a
/// row.
/// E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid) -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r0, %s0 = gpu.shuffle xor %a, %c1, %c32 : f32
/// %a0 = arith.addf %a, %r0 : f32
/// %r1, %s1 = gpu.shuffle xor %a0, %c2, %c32 : f32
/// %a1 = arith.addf %a0, %r1 : f32
/// ...
/// %r4, %s4 = gpu.shuffle xor %a3, %c16, %c32 : f32
/// %r = arith.addf %a3, %r4 : f32
/// ```
struct ReductionToGPUWarpShuffle
    : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only warp_size-sized vectors supported.
    if (static_cast<uint64_t>(vectorType.getShape()[0]) !=
        warpOp.getWarpSize())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must match warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    Location yieldLoc = yieldOperand->getOwner()->getLoc();

    // Yield the vector that will be reduced from the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Every lane has one scalar value. These should be reduced.
    Value laneValVec = newWarpOp.getResult(numResults);
    Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);

    // Parallel reduction using butterfly shuffles.
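    // With a warp size of 32 the shuffle offsets are 1, 2, 4, 8, 16: in each
    // round lane i combines its value with lane (i ^ offset)'s, so after
    // log2(warpSize) rounds every lane holds the full reduction.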
    for (uint64_t i = 1; i < newWarpOp.getWarpSize(); i <<= 1) {
      Value shuffled = rewriter
                           .create<gpu::ShuffleOp>(
                               reductionOp.getLoc(), laneVal, i,
                               /*width=*/newWarpOp.getWarpSize(),
                               /*mode=*/gpu::ShuffleMode::XOR)
                           .result();
      laneVal = makeArithReduction(rewriter, reductionOp.getLoc(),
                                   reductionOp.getKind(), laneVal, shuffled);
    }

    newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
    return success();
  }
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}

void mlir::vector::populateReductionToGPUWarpShufflePatterns(
    RewritePatternSet &patterns) {
  patterns.add<ReductionToGPUWarpShuffle>(patterns.getContext());
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}
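// Example usage (illustrative sketch, not part of this file's API): a pass
// would typically populate a subset of these patterns and run a greedy
// rewrite, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   vector::populatePropagateWarpVectorDistributionPatterns(patterns);
//   vector::populateReductionToGPUWarpShufflePatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));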