//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"

using namespace mlir;
using namespace mlir::vector;

/// Lower `warpOp` to an scf.if that is executed only by lane 0. Values flowing
/// into and out of the warp region are exchanged through scratch-pad buffers
/// created by `options.warpAllocationFn`:
///   * each region argument is stored by every lane (at laneid * chunk size)
///     before the scf.if, and re-loaded as a full vector inside the branch;
///   * each yielded value is stored inside the branch and re-loaded after the
///     scf.if, either as a broadcast (same yielded/result type) or as a
///     per-lane vector slice (distributed case).
/// `options.warpSyncronizationFn` is inserted between the stores and the loads
/// on both paths so the buffer contents are visible warp-wide.
static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op guarded by "laneid == 0".
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  // The auto-created terminator is erased here and re-created at the very end,
  // after the warp body has been merged into the then-block.
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    // Allocate the buffer before the scf.if so both sides can access it.
    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            bbArg.getType());

    // Store arg vector into buffer: every lane writes its chunk at offset
    // laneid * storeSize, together materializing the full vector.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer: lane 0 (inside the scf.if) reads the
    // complete vector back starting at offset 0.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Move body of warpOp to ifOp, substituting the re-loaded full vectors for
  // the original block arguments.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);

  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer = options.warpAllocationFn(warpOp->getLoc(), rewriter, warpOp,
                                            val.getType());

    // Store yielded value into buffer (inside the scf.if, i.e. lane 0 only).
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcasted to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      // Distributed case: every lane reads back only its own chunk.
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(warpOp->getLoc(), rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with different signature.
138 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns( 139 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, 140 ValueRange newYieldedValues, TypeRange newReturnTypes) { 141 // Create a new op before the existing one, with the extra operands. 142 OpBuilder::InsertionGuard g(rewriter); 143 rewriter.setInsertionPoint(warpOp); 144 auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>( 145 warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(), 146 warpOp.getArgs(), warpOp.getBody()->getArgumentTypes()); 147 148 Region &opBody = warpOp.getBodyRegion(); 149 Region &newOpBody = newWarpOp.getBodyRegion(); 150 rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin()); 151 auto yield = 152 cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator()); 153 154 rewriter.updateRootInPlace( 155 yield, [&]() { yield.operandsMutable().assign(newYieldedValues); }); 156 return newWarpOp; 157 } 158 159 /// Helper to create a new WarpExecuteOnLane0Op region with extra outputs. 160 static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns( 161 RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp, 162 ValueRange newYieldedValues, TypeRange newReturnTypes) { 163 SmallVector<Type> types(warpOp.getResultTypes().begin(), 164 warpOp.getResultTypes().end()); 165 types.append(newReturnTypes.begin(), newReturnTypes.end()); 166 auto yield = cast<vector::YieldOp>( 167 warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 168 SmallVector<Value> yieldValues(yield.getOperands().begin(), 169 yield.getOperands().end()); 170 yieldValues.append(newYieldedValues.begin(), newYieldedValues.end()); 171 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( 172 rewriter, warpOp, yieldValues, types); 173 rewriter.replaceOp(warpOp, 174 newWarpOp.getResults().take_front(warpOp.getNumResults())); 175 return newWarpOp; 176 } 177 178 /// Helper to know if an op can be hoisted out of the region. 
179 static bool canBeHoisted(Operation *op, 180 function_ref<bool(Value)> definedOutside) { 181 return llvm::all_of(op->getOperands(), definedOutside) && 182 isSideEffectFree(op) && op->getNumRegions() == 0; 183 } 184 185 /// Return a value yielded by `warpOp` which statifies the filter lamdba 186 /// condition and is not dead. 187 static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp, 188 std::function<bool(Operation *)> fn) { 189 auto yield = cast<vector::YieldOp>( 190 warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); 191 for (OpOperand &yieldOperand : yield->getOpOperands()) { 192 Value yieldValues = yieldOperand.get(); 193 Operation *definedOp = yieldValues.getDefiningOp(); 194 if (definedOp && fn(definedOp)) { 195 if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty()) 196 return &yieldOperand; 197 } 198 } 199 return {}; 200 } 201 202 // Clones `op` into a new operation that takes `operands` and returns 203 // `resultTypes`. 204 static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter, 205 Location loc, Operation *op, 206 ArrayRef<Value> operands, 207 ArrayRef<Type> resultTypes) { 208 OperationState res(loc, op->getName().getStringRef(), operands, resultTypes, 209 op->getAttrs()); 210 return rewriter.create(res); 211 } 212 213 /// Currently the distribution map is implicit based on the vector shape. In the 214 /// future it will be part of the op. 215 /// Example: 216 /// ``` 217 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) { 218 /// ... 
219 /// vector.yield %3 : vector<32x16x64xf32> 220 /// } 221 /// ``` 222 /// Would have an implicit map of: 223 /// `(d0, d1, d2) -> (d0, d2)` 224 static AffineMap calculateImplicitMap(Value yield, Value ret) { 225 auto srcType = yield.getType().cast<VectorType>(); 226 auto dstType = ret.getType().cast<VectorType>(); 227 SmallVector<AffineExpr> perm; 228 // Check which dimensions of the yield value are different than the dimensions 229 // of the result to know the distributed dimensions. Then associate each 230 // distributed dimension to an ID in order. 231 for (unsigned i = 0, e = srcType.getRank(); i < e; i++) { 232 if (srcType.getDimSize(i) != dstType.getDimSize(i)) 233 perm.push_back(getAffineDimExpr(i, yield.getContext())); 234 } 235 auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext()); 236 return map; 237 } 238 239 namespace { 240 241 struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> { 242 WarpOpToScfForPattern(MLIRContext *context, 243 const WarpExecuteOnLane0LoweringOptions &options, 244 PatternBenefit benefit = 1) 245 : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit), 246 options(options) {} 247 248 LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, 249 PatternRewriter &rewriter) const override { 250 return rewriteWarpOpToScfFor(rewriter, warpOp, options); 251 } 252 253 private: 254 const WarpExecuteOnLane0LoweringOptions &options; 255 }; 256 257 /// Distribute transfer_write ops based on the affine map returned by 258 /// `distributionMapFn`. 259 /// Example: 260 /// ``` 261 /// %0 = vector.warp_execute_on_lane_0(%id){ 262 /// ... 263 /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32> 264 /// vector.yield 265 /// } 266 /// ``` 267 /// To 268 /// ``` 269 /// %r:3 = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) { 270 /// ... 
271 /// vector.yield %v : vector<32xf32> 272 /// } 273 /// vector.transfer_write %v, %A[%id] : vector<1xf32>, memref<128xf32> 274 struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> { 275 WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn, 276 PatternBenefit b = 1) 277 : OpRewritePattern<vector::TransferWriteOp>(ctx, b), 278 distributionMapFn(fn) {} 279 280 /// Distribute the TransferWriteOp. Only 1D distributions and vector dims that 281 /// are multiples of the distribution ratio are supported at the moment. 282 LogicalResult tryDistributeOp(RewriterBase &rewriter, 283 vector::TransferWriteOp writeOp, 284 WarpExecuteOnLane0Op warpOp) const { 285 AffineMap map = distributionMapFn(writeOp); 286 SmallVector<int64_t> targetShape(writeOp.getVectorType().getShape().begin(), 287 writeOp.getVectorType().getShape().end()); 288 assert(map.getNumResults() == 1 && 289 "multi-dim distribution not implemented yet"); 290 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { 291 unsigned position = map.getDimPosition(i); 292 if (targetShape[position] % warpOp.getWarpSize() != 0) 293 return failure(); 294 targetShape[position] = targetShape[position] / warpOp.getWarpSize(); 295 } 296 VectorType targetType = 297 VectorType::get(targetShape, writeOp.getVectorType().getElementType()); 298 299 SmallVector<Value> yieldValues = {writeOp.getVector()}; 300 SmallVector<Type> retTypes = {targetType}; 301 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 302 rewriter, warpOp, yieldValues, retTypes); 303 rewriter.setInsertionPointAfter(newWarpOp); 304 305 // Move op outside of region: Insert clone at the insertion point and delete 306 // the old op. 
307 auto newWriteOp = 308 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); 309 rewriter.eraseOp(writeOp); 310 311 rewriter.setInsertionPoint(newWriteOp); 312 AffineMap indexMap = map.compose(newWriteOp.getPermutationMap()); 313 Location loc = newWriteOp.getLoc(); 314 SmallVector<Value> indices(newWriteOp.getIndices().begin(), 315 newWriteOp.getIndices().end()); 316 for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) { 317 AffineExpr d0, d1; 318 bindDims(newWarpOp.getContext(), d0, d1); 319 auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>(); 320 if (!indexExpr) 321 continue; 322 unsigned indexPos = indexExpr.getPosition(); 323 unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition(); 324 auto scale = 325 getAffineConstantExpr(targetShape[vectorPos], newWarpOp.getContext()); 326 indices[indexPos] = 327 makeComposedAffineApply(rewriter, loc, d0 + scale * d1, 328 {indices[indexPos], newWarpOp.getLaneid()}); 329 } 330 newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back()); 331 newWriteOp.getIndicesMutable().assign(indices); 332 333 return success(); 334 } 335 336 /// Extract TransferWriteOps of vector<1x> into a separate warp op. 337 LogicalResult tryExtractOp(RewriterBase &rewriter, 338 vector::TransferWriteOp writeOp, 339 WarpExecuteOnLane0Op warpOp) const { 340 Location loc = writeOp.getLoc(); 341 VectorType vecType = writeOp.getVectorType(); 342 343 // Only vector<1x> is supported at the moment. 344 if (vecType.getShape().size() != 1 || vecType.getShape()[0] != 1) 345 return failure(); 346 347 // Do not process warp ops that contain only TransferWriteOps. 
348 if (llvm::all_of(warpOp.getOps(), [](Operation &op) { 349 return isa<vector::TransferWriteOp, vector::YieldOp>(&op); 350 })) 351 return failure(); 352 353 SmallVector<Value> yieldValues = {writeOp.getVector()}; 354 SmallVector<Type> retTypes = {vecType}; 355 WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns( 356 rewriter, warpOp, yieldValues, retTypes); 357 rewriter.setInsertionPointAfter(newWarpOp); 358 359 // Create a second warp op that contains only writeOp. 360 auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>( 361 loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize()); 362 Block &body = secondWarpOp.getBodyRegion().front(); 363 rewriter.setInsertionPointToStart(&body); 364 auto newWriteOp = 365 cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation())); 366 newWriteOp.getVectorMutable().assign( 367 newWarpOp.getResult(newWarpOp.getNumResults() - 1)); 368 rewriter.eraseOp(writeOp); 369 rewriter.create<vector::YieldOp>(newWarpOp.getLoc()); 370 return success(); 371 } 372 373 LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, 374 PatternRewriter &rewriter) const override { 375 // Ops with mask not supported yet. 376 if (writeOp.getMask()) 377 return failure(); 378 379 auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp()); 380 if (!warpOp) 381 return failure(); 382 383 // There must be no op with a side effect after writeOp. 
384 Operation *nextOp = writeOp.getOperation(); 385 while ((nextOp = nextOp->getNextNode())) 386 if (!isSideEffectFree(nextOp)) 387 return failure(); 388 389 if (!llvm::all_of(writeOp->getOperands(), [&](Value value) { 390 return writeOp.getVector() == value || 391 warpOp.isDefinedOutsideOfRegion(value); 392 })) 393 return failure(); 394 395 if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp))) 396 return success(); 397 398 if (succeeded(tryExtractOp(rewriter, writeOp, warpOp))) 399 return success(); 400 401 return failure(); 402 } 403 404 private: 405 DistributionMapFn distributionMapFn; 406 }; 407 408 /// Sink out elementwise op feeding into a warp op yield. 409 /// ``` 410 /// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { 411 /// ... 412 /// %3 = arith.addf %1, %2 : vector<32xf32> 413 /// vector.yield %3 : vector<32xf32> 414 /// } 415 /// ``` 416 /// To 417 /// ``` 418 /// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, 419 /// vector<1xf32>, vector<1xf32>) { 420 /// ... 
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///   vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a yielded, still-used value produced by an elementwise op.
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    // Yield each operand of the elementwise op out of the warp region, using a
    // distributed (narrowed) type with the shape of the distributed result.
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        // Scalar result: operands keep their (scalar) types unchanged.
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);
    // Re-create the elementwise op after the warp op, consuming the newly
    // appended distributed results (operand i maps to result numResults + i).
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(i + numResults);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
/// vector<1xf32>, vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///   vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a yielded, still-used transfer_read result.
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    // Map distributed vector dimensions to memref dimensions via the read's
    // permutation map.
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Offset each distributed index: index = index + laneid * chunkSize.
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
    // Re-create the read after the warp op with the distributed result type
    // and the per-lane indices.
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this in WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Keep only results that still have uses, and their yield operands.
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    // Nothing was dead: do not rewrite.
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    // Remap the surviving results to the compacted result list.
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

// If an operand is directly yielded out of the region we can forward it
// directly and it doesn't need to go through the region.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    // Only read below when `valForwarded` has been set alongside it.
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        // Forward only when the types match exactly (no distribution needed).
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      // Otherwise, forward a yielded block argument of the warp region by
      // substituting the corresponding op-level argument.
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};

/// Sink a yielded vector.broadcast out of the warp op: yield the (uniform)
/// source instead and broadcast it to the distributed type after the op.
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    // Find a yielded, still-used broadcast result.
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    // Yield the broadcast source (with its original, undistributed type).
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()});
    rewriter.setInsertionPointAfter(newWarpOp);
    // Re-create the broadcast outside, targeting the distributed result type.
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResults().back());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);

    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't change
/// the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///   -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// %w0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %q0)
///   -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///   args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///   ^bb0(%arg: vector<128xf32>):
///     ...
///     vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      // Temporarily yield the loop's init value instead of its result; the
      // distributed value is threaded through the new loop below.
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op region
    // inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    // Arguments for merging the old loop body into the inner warp region:
    // the induction variable first, then the inner warp's block arguments.
    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    // Save the old loop's yielded values before erasing its terminator.
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      // Feed the warp result in as the loop's init operand; the "+ 3" skips
      // the scf.for control operands (lower bound, upper bound, step).
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, DistributionMapFn distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  // An operand defined by an op already marked for hoisting counts as
  // "outside", which lets chains of uniform ops be hoisted together.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and hoist
  // operations from there.
  for (auto &op : body->without_terminator()) {
    // Only ops without vector results are hoisted here; vector values are the
    // business of the distribution patterns above.
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}