//===- VectorDistribute.cpp - patterns to do vector distribution ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Transforms/SideEffectUtils.h"

#include <utility>

using namespace mlir;
using namespace mlir::vector;

static LogicalResult
rewriteWarpOpToScfFor(RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
                      const WarpExecuteOnLane0LoweringOptions &options) {
  assert(warpOp.getBodyRegion().hasOneBlock() &&
         "expected WarpOp with single block");
  Block *warpOpBody = &warpOp.getBodyRegion().front();
  Location loc = warpOp.getLoc();

  // Passed all checks. Start rewriting.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);

  // Create scf.if op.
  Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value isLane0 = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                                 warpOp.getLaneid(), c0);
  auto ifOp = rewriter.create<scf::IfOp>(loc, isLane0,
                                         /*withElseRegion=*/false);
  rewriter.eraseOp(ifOp.thenBlock()->getTerminator());

  // Store vectors that are defined outside of warpOp into the scratch pad
  // buffer.
  SmallVector<Value> bbArgReplacements;
  for (const auto &it : llvm::enumerate(warpOp.getArgs())) {
    Value val = it.value();
    Value bbArg = warpOpBody->getArgument(it.index());

    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, bbArg.getType());

    // Store arg vector into buffer.
    rewriter.setInsertionPoint(ifOp);
    auto vectorType = val.getType().cast<VectorType>();
    int64_t storeSize = vectorType.getShape()[0];
    Value storeOffset = rewriter.create<arith::MulIOp>(
        loc, warpOp.getLaneid(),
        rewriter.create<arith::ConstantIndexOp>(loc, storeSize));
    rewriter.create<vector::StoreOp>(loc, val, buffer, storeOffset);

    // Load bbArg vector from buffer.
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto bbArgType = bbArg.getType().cast<VectorType>();
    Value loadOp = rewriter.create<vector::LoadOp>(loc, bbArgType, buffer, c0);
    bbArgReplacements.push_back(loadOp);
  }

  // Insert sync after all the stores and before all the loads.
  if (!warpOp.getArgs().empty()) {
    rewriter.setInsertionPoint(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Move body of warpOp to ifOp.
  rewriter.mergeBlocks(warpOpBody, ifOp.thenBlock(), bbArgReplacements);
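
  // At this point every warp-op argument has been passed through memory: each
  // lane stores its chunk at offset `laneid * chunk`, and lane 0 loads the
  // full vector back inside the scf.if. Illustrative sketch (the buffer
  // allocation and the synchronization op are supplied by the options, so the
  // exact ops vary):
  //
  //   vector.store %arg, %buffer[%laneid * chunk]  // all lanes
  //   <sync>                                       // warpSyncronizationFn
  //   scf.if %laneid == 0 {
  //     %bbArg = vector.load %buffer[%c0]          // full vector
  //     ...                                        // original body
  //   }
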
  // Rewrite terminator and compute replacements of WarpOp results.
  SmallVector<Value> replacements;
  auto yieldOp = cast<vector::YieldOp>(ifOp.thenBlock()->getTerminator());
  Location yieldLoc = yieldOp.getLoc();
  for (const auto &it : llvm::enumerate(yieldOp.operands())) {
    Value val = it.value();
    Type resultType = warpOp->getResultTypes()[it.index()];
    rewriter.setInsertionPoint(ifOp);
    Value buffer =
        options.warpAllocationFn(loc, rewriter, warpOp, val.getType());

    // Store yielded value into buffer.
    rewriter.setInsertionPoint(yieldOp);
    if (val.getType().isa<VectorType>())
      rewriter.create<vector::StoreOp>(yieldLoc, val, buffer, c0);
    else
      rewriter.create<memref::StoreOp>(yieldLoc, val, buffer, c0);

    // Load value from buffer (after warpOp).
    rewriter.setInsertionPointAfter(ifOp);
    if (resultType == val.getType()) {
      // Result type and yielded value type are the same. This is a broadcast.
      // E.g.:
      // %r = vector.warp_execute_on_lane_0(...) -> (f32) {
      //   vector.yield %cst : f32
      // }
      // Both types are f32. The constant %cst is broadcast to all lanes.
      // This is described in more detail in the documentation of the op.
      Value loadOp = rewriter.create<memref::LoadOp>(loc, buffer, c0);
      replacements.push_back(loadOp);
    } else {
      auto loadedVectorType = resultType.cast<VectorType>();
      int64_t loadSize = loadedVectorType.getShape()[0];

      // loadOffset = laneid * loadSize
      Value loadOffset = rewriter.create<arith::MulIOp>(
          loc, warpOp.getLaneid(),
          rewriter.create<arith::ConstantIndexOp>(loc, loadSize));
      Value loadOp = rewriter.create<vector::LoadOp>(loc, loadedVectorType,
                                                     buffer, loadOffset);
      replacements.push_back(loadOp);
    }
  }

  // Insert sync after all the stores and before all the loads.
  if (!yieldOp.operands().empty()) {
    rewriter.setInsertionPointAfter(ifOp);
    options.warpSyncronizationFn(loc, rewriter, warpOp);
  }

  // Delete terminator and add empty scf.yield.
  rewriter.eraseOp(yieldOp);
  rewriter.setInsertionPointToEnd(ifOp.thenBlock());
  rewriter.create<scf::YieldOp>(yieldLoc);

  // Compute replacements for WarpOp results.
  rewriter.replaceOp(warpOp, replacements);

  return success();
}

/// Helper to create a new WarpExecuteOnLane0Op with a different signature.
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  // Create a new op before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(warpOp);
  auto newWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());

  Region &opBody = warpOp.getBodyRegion();
  Region &newOpBody = newWarpOp.getBodyRegion();
  Block &newOpFirstBlock = newOpBody.front();
  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
  rewriter.eraseBlock(&newOpFirstBlock);
  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
         "expected WarpOp with single block");

  auto yield =
      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());

  rewriter.updateRootInPlace(
      yield, [&]() { yield.operandsMutable().assign(newYieldedValues); });
  return newWarpOp;
}

/// Helper to create a new WarpExecuteOnLane0Op region with extra outputs.
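/// The original `warpOp` results are replaced by the leading results of the
/// new op. For example (illustrative), appending a `vector<32xf32>` yield to
/// a warp op that returns `f32` produces a new warp op returning
/// `(f32, vector<32xf32>)`.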
static WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
    RewriterBase &rewriter, WarpExecuteOnLane0Op warpOp,
    ValueRange newYieldedValues, TypeRange newReturnTypes) {
  SmallVector<Type> types(warpOp.getResultTypes().begin(),
                          warpOp.getResultTypes().end());
  types.append(newReturnTypes.begin(), newReturnTypes.end());
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  SmallVector<Value> yieldValues(yield.getOperands().begin(),
                                 yield.getOperands().end());
  yieldValues.append(newYieldedValues.begin(), newYieldedValues.end());
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
      rewriter, warpOp, yieldValues, types);
  rewriter.replaceOp(warpOp,
                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
  return newWarpOp;
}

/// Helper to know if an op can be hoisted out of the region.
static bool canBeHoisted(Operation *op,
                         function_ref<bool(Value)> definedOutside) {
  return llvm::all_of(op->getOperands(), definedOutside) &&
         isSideEffectFree(op) && op->getNumRegions() == 0;
}

/// Return a value yielded by `warpOp` which satisfies the filter lambda
/// condition and is not dead.
static OpOperand *getWarpResult(WarpExecuteOnLane0Op warpOp,
                                std::function<bool(Operation *)> fn) {
  auto yield = cast<vector::YieldOp>(
      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
  for (OpOperand &yieldOperand : yield->getOpOperands()) {
    Value yieldValues = yieldOperand.get();
    Operation *definedOp = yieldValues.getDefiningOp();
    if (definedOp && fn(definedOp)) {
      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
        return &yieldOperand;
    }
  }
  return {};
}

/// Clones `op` into a new operation that takes `operands` and returns
/// `resultTypes`.
static Operation *cloneOpWithOperandsAndTypes(RewriterBase &rewriter,
                                              Location loc, Operation *op,
                                              ArrayRef<Value> operands,
                                              ArrayRef<Type> resultTypes) {
  OperationState res(loc, op->getName().getStringRef(), operands, resultTypes,
                     op->getAttrs());
  return rewriter.create(res);
}

/// Currently the distribution map is implicit based on the vector shape. In
/// the future it will be part of the op.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1x16x2xf32>) {
///   ...
///   vector.yield %3 : vector<32x16x64xf32>
/// }
/// ```
/// Would have an implicit map of:
/// `(d0, d1, d2) -> (d0, d2)`
static AffineMap calculateImplicitMap(Value yield, Value ret) {
  auto srcType = yield.getType().cast<VectorType>();
  auto dstType = ret.getType().cast<VectorType>();
  SmallVector<AffineExpr> perm;
  // Check which dimensions of the yield value are different from the
  // dimensions of the result to know the distributed dimensions. Then
  // associate each distributed dimension to an ID in order.
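  // E.g., for a yield of vector<32x16x64xf32> distributed to
  // vector<1x16x2xf32> (the example above), dimensions 0 and 2 differ, so the
  // resulting map is (d0, d1, d2) -> (d0, d2).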
  for (unsigned i = 0, e = srcType.getRank(); i < e; i++) {
    if (srcType.getDimSize(i) != dstType.getDimSize(i))
      perm.push_back(getAffineDimExpr(i, yield.getContext()));
  }
  auto map = AffineMap::get(srcType.getRank(), 0, perm, yield.getContext());
  return map;
}

namespace {

struct WarpOpToScfForPattern : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpToScfForPattern(MLIRContext *context,
                        const WarpExecuteOnLane0LoweringOptions &options,
                        PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        options(options) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    return rewriteWarpOpToScfFor(rewriter, warpOp, options);
  }

private:
  const WarpExecuteOnLane0LoweringOptions &options;
};

/// Clone `writeOp` assumed to be nested under `warpOp` into a new warp execute
/// op with the proper return type.
/// The new write op is updated to write the result of the new warp execute op.
/// The old `writeOp` is deleted.
static vector::TransferWriteOp cloneWriteOp(RewriterBase &rewriter,
                                            WarpExecuteOnLane0Op warpOp,
                                            vector::TransferWriteOp writeOp,
                                            VectorType targetType) {
  assert(writeOp->getParentOp() == warpOp &&
         "write must be nested immediately under warp");
  OpBuilder::InsertionGuard g(rewriter);
  WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
      rewriter, warpOp, ValueRange{{writeOp.getVector()}},
      TypeRange{targetType});
  rewriter.setInsertionPointAfter(newWarpOp);
  auto newWriteOp =
      cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
  rewriter.eraseOp(writeOp);
  newWriteOp.getVectorMutable().assign(newWarpOp.getResults().back());
  return newWriteOp;
}

/// Distribute transfer_write ops based on the affine map returned by
/// `distributionMapFn`.
/// Example:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%id){
///   ...
///   vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
///   vector.yield
/// }
/// ```
/// To
/// ```
/// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
///   ...
///   vector.yield %v : vector<32xf32>
/// }
/// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
/// ```
struct WarpOpTransferWrite : public OpRewritePattern<vector::TransferWriteOp> {
  WarpOpTransferWrite(MLIRContext *ctx, DistributionMapFn fn,
                      PatternBenefit b = 1)
      : OpRewritePattern<vector::TransferWriteOp>(ctx, b),
        distributionMapFn(std::move(fn)) {}

  /// Distribute the TransferWriteOp. Only 1D distributions and vector dims
  /// that are multiples of the distribution ratio are supported at the moment.
  LogicalResult tryDistributeOp(RewriterBase &rewriter,
                                vector::TransferWriteOp writeOp,
                                WarpExecuteOnLane0Op warpOp) const {
    VectorType writtenVectorType = writeOp.getVectorType();

    // 1. A 0-D write cannot be distributed. Return failure so that
    // `tryExtractOp` can instead clone it into a separate
    // WarpExecuteOnLane0Op.
    if (writtenVectorType.getRank() == 0)
      return failure();

    // 2. Compute the distribution map.
    AffineMap map = distributionMapFn(writeOp);
    if (map.getNumResults() != 1)
      return writeOp->emitError("multi-dim distribution not implemented yet");

    // 3. Compute the targetType using the distribution map.
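    // E.g., a vector<64xf32> write distributed with map (d0) -> (d0) over a
    // warp of size 32 yields a target type of vector<2xf32> (illustrative;
    // each distributed dim must be a multiple of the warp size).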
    SmallVector<int64_t> targetShape(writtenVectorType.getShape().begin(),
                                     writtenVectorType.getShape().end());
    for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
      unsigned position = map.getDimPosition(i);
      if (targetShape[position] % warpOp.getWarpSize() != 0)
        return failure();
      targetShape[position] = targetShape[position] / warpOp.getWarpSize();
    }
    VectorType targetType =
        VectorType::get(targetShape, writtenVectorType.getElementType());

    // 4. Clone the write into a new WarpExecuteOnLane0Op to separate it from
    // the rest.
    vector::TransferWriteOp newWriteOp =
        cloneWriteOp(rewriter, warpOp, writeOp, targetType);

    // 5. Reindex the write using the distribution map.
    auto newWarpOp =
        newWriteOp.getVector().getDefiningOp<WarpExecuteOnLane0Op>();
    rewriter.setInsertionPoint(newWriteOp);
    AffineMap indexMap = map.compose(newWriteOp.getPermutationMap());
    Location loc = newWriteOp.getLoc();
    SmallVector<Value> indices(newWriteOp.getIndices().begin(),
                               newWriteOp.getIndices().end());
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(newWarpOp.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      auto scale = rewriter.getAffineConstantExpr(targetShape[vectorPos]);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
                                  {indices[indexPos], newWarpOp.getLaneid()});
    }
    newWriteOp.getIndicesMutable().assign(indices);

    return success();
  }

  /// Extract TransferWriteOps of vector<1x> into a separate warp op.
  LogicalResult tryExtractOp(RewriterBase &rewriter,
                             vector::TransferWriteOp writeOp,
                             WarpExecuteOnLane0Op warpOp) const {
    Location loc = writeOp.getLoc();
    VectorType vecType = writeOp.getVectorType();

    // Only sink out vectors of a single element for now to avoid serializing
    // large vector stores. This can later be controlled by the user.
    if (vecType.getNumElements() != 1)
      return failure();

    // Do not process warp ops that contain only TransferWriteOps.
    if (llvm::all_of(warpOp.getOps(), [](Operation &op) {
          return isa<vector::TransferWriteOp, vector::YieldOp>(&op);
        }))
      return failure();

    SmallVector<Value> yieldValues = {writeOp.getVector()};
    SmallVector<Type> retTypes = {vecType};
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Create a second warp op that contains only writeOp.
    auto secondWarpOp = rewriter.create<WarpExecuteOnLane0Op>(
        loc, TypeRange(), newWarpOp.getLaneid(), newWarpOp.getWarpSize());
    Block &body = secondWarpOp.getBodyRegion().front();
    rewriter.setInsertionPointToStart(&body);
    auto newWriteOp =
        cast<vector::TransferWriteOp>(rewriter.clone(*writeOp.getOperation()));
    newWriteOp.getVectorMutable().assign(
        newWarpOp.getResult(newWarpOp.getNumResults() - 1));
    rewriter.eraseOp(writeOp);
    rewriter.create<vector::YieldOp>(newWarpOp.getLoc());
    return success();
  }

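  /// A write is only distributed when it is the last op with side effects in
  /// the warp region and all of its operands, except the stored vector, are
  /// defined outside the region. `tryDistributeOp` is attempted first;
  /// `tryExtractOp` serves as a fallback for single-element vectors.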
  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Ops with mask not supported yet.
    if (writeOp.getMask())
      return failure();

    auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(writeOp->getParentOp());
    if (!warpOp)
      return failure();

    // There must be no op with a side effect after writeOp.
    Operation *nextOp = writeOp.getOperation();
    while ((nextOp = nextOp->getNextNode()))
      if (!isSideEffectFree(nextOp))
        return failure();

    if (!llvm::all_of(writeOp->getOperands(), [&](Value value) {
          return writeOp.getVector() == value ||
                 warpOp.isDefinedOutsideOfRegion(value);
        }))
      return failure();

    if (succeeded(tryDistributeOp(rewriter, writeOp, warpOp)))
      return success();

    if (succeeded(tryExtractOp(rewriter, writeOp, warpOp)))
      return success();

    return failure();
  }

private:
  DistributionMapFn distributionMapFn;
};

/// Sink out elementwise op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %3 = arith.addf %1, %2 : vector<32xf32>
///   vector.yield %3 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:3 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>,
///     vector<1xf32>, vector<1xf32>) {
///   ...
///   %4 = arith.addf %2, %3 : vector<32xf32>
///   vector.yield %4, %2, %3 : vector<32xf32>, vector<32xf32>,
///     vector<32xf32>
/// }
/// %0 = arith.addf %r#1, %r#2 : vector<1xf32>
/// ```
struct WarpOpElementwise : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(warpOp, [](Operation *op) {
      return OpTrait::hasElementwiseMappableTraits(op);
    });
    if (!yieldOperand)
      return failure();
    Operation *elementWise = yieldOperand->get().getDefiningOp();
    unsigned operandIndex = yieldOperand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);
    SmallVector<Value> yieldValues;
    SmallVector<Type> retTypes;
    Location loc = warpOp.getLoc();
    for (OpOperand &operand : elementWise->getOpOperands()) {
      Type targetType;
      if (auto vecType = distributedVal.getType().dyn_cast<VectorType>()) {
        // If the result type is a vector, the operands must also be vectors.
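        // They are distributed to the same shape as the result but keep their
        // own element type (e.g., an elementwise arith.cmpf yields i1 from
        // f32 operands; illustrative).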
        auto operandType = operand.get().getType().cast<VectorType>();
        targetType =
            VectorType::get(vecType.getShape(), operandType.getElementType());
      } else {
        auto operandType = operand.get().getType();
        assert(!operandType.isa<VectorType>() &&
               "unexpected yield of vector from op with scalar result type");
        targetType = operandType;
      }
      retTypes.push_back(targetType);
      yieldValues.push_back(operand.get());
    }
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);
    SmallVector<Value> newOperands(elementWise->getOperands().begin(),
                                   elementWise->getOperands().end());
    for (unsigned i : llvm::seq(unsigned(0), elementWise->getNumOperands())) {
      newOperands[i] = newWarpOp.getResult(i + numResults);
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(newWarpOp);
    Operation *newOp = cloneOpWithOperandsAndTypes(
        rewriter, loc, elementWise, newOperands,
        {newWarpOp.getResult(operandIndex).getType()});
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(newOp->getResult(0));
    return success();
  }
};

/// Sink out transfer_read op feeding into a warp op yield.
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %dead = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   ...
///   %2 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>,
///     vector<32xf32>
///   vector.yield %2 : vector<32xf32>
/// }
/// %0 = vector.transfer_read %src[%c0], %cst : memref<1024xf32>, vector<1xf32>
/// ```
struct WarpOpTransferRead : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::TransferReadOp>(op); });
    if (!operand)
      return failure();
    auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
    unsigned operandIndex = operand->getOperandNumber();
    Value distributedVal = warpOp.getResult(operandIndex);

    SmallVector<Value, 4> indices(read.getIndices().begin(),
                                  read.getIndices().end());
    AffineMap map = calculateImplicitMap(read.getResult(), distributedVal);
    AffineMap indexMap = map.compose(read.getPermutationMap());
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    for (auto it : llvm::zip(indexMap.getResults(), map.getResults())) {
      AffineExpr d0, d1;
      bindDims(read.getContext(), d0, d1);
      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
      if (!indexExpr)
        continue;
      unsigned indexPos = indexExpr.getPosition();
      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
      int64_t scale =
          distributedVal.getType().cast<VectorType>().getDimSize(vectorPos);
      indices[indexPos] =
          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
                                  {indices[indexPos], warpOp.getLaneid()});
    }
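    // Create the distributed read outside the warp op, reusing the original
    // permutation map, padding, mask, and in_bounds attributes, but with the
    // per-lane indices computed above.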
    Value newRead = rewriter.create<vector::TransferReadOp>(
        read.getLoc(), distributedVal.getType(), read.getSource(), indices,
        read.getPermutationMapAttr(), read.getPadding(), read.getMask(),
        read.getInBoundsAttr());
    distributedVal.replaceAllUsesWith(newRead);
    return success();
  }
};

/// Remove any result that has no use along with the matching yieldOp operand.
// TODO: Move this into WarpExecuteOnLane0Op canonicalization.
struct WarpOpDeadResult : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      resultTypes.push_back(result.getType());
      yieldValues.push_back(yield.getOperand(result.getResultNumber()));
    }
    if (yield.getNumOperands() == yieldValues.size())
      return failure();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
        rewriter, warpOp, yieldValues, resultTypes);
    unsigned resultIndex = 0;
    for (OpResult result : warpOp.getResults()) {
      if (result.use_empty())
        continue;
      result.replaceAllUsesWith(newWarpOp.getResult(resultIndex++));
    }
    rewriter.eraseOp(warpOp);
    return success();
  }
};

/// If an operand is directly yielded out of the region, it can be forwarded
/// without going through the region.
struct WarpOpForwardOperand : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Type> resultTypes;
    SmallVector<Value> yieldValues;
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    Value valForwarded;
    unsigned resultIndex;
    for (OpOperand &operand : yield->getOpOperands()) {
      Value result = warpOp.getResult(operand.getOperandNumber());
      if (result.use_empty())
        continue;

      // Assume all the values coming from above are uniform.
      if (!warpOp.getBodyRegion().isAncestor(operand.get().getParentRegion())) {
        if (result.getType() != operand.get().getType())
          continue;
        valForwarded = operand.get();
        resultIndex = operand.getOperandNumber();
        break;
      }
      auto arg = operand.get().dyn_cast<BlockArgument>();
      if (!arg || arg.getOwner()->getParentOp() != warpOp.getOperation())
        continue;
      Value warpOperand = warpOp.getArgs()[arg.getArgNumber()];
      if (result.getType() != warpOperand.getType())
        continue;
      valForwarded = warpOperand;
      resultIndex = operand.getOperandNumber();
      break;
    }
    if (!valForwarded)
      return failure();
    warpOp.getResult(resultIndex).replaceAllUsesWith(valForwarded);
    return success();
  }
};
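
/// Sink out a broadcast feeding into a warp op yield: the broadcast source is
/// yielded instead, and the broadcast is recreated after the warp op with the
/// distributed result type. Illustrative sketch (names for exposition only;
/// the dead result is cleaned up later by WarpOpDeadResult):
/// ```
/// %0 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) {
///   %1 = vector.broadcast %src : f32 to vector<32xf32>
///   vector.yield %1 : vector<32xf32>
/// }
/// ```
/// To
/// ```
/// %r:2 = vector.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>, f32) {
///   %1 = vector.broadcast %src : f32 to vector<32xf32>
///   vector.yield %1, %src : vector<32xf32>, f32
/// }
/// %0 = vector.broadcast %r#1 : f32 to vector<1xf32>
/// ```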
struct WarpOpBroadcast : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *operand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::BroadcastOp>(op); });
    if (!operand)
      return failure();
    unsigned int operandNumber = operand->getOperandNumber();
    auto broadcastOp = operand->get().getDefiningOp<vector::BroadcastOp>();
    Location loc = broadcastOp.getLoc();
    auto destVecType =
        warpOp->getResultTypes()[operandNumber].cast<VectorType>();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, {broadcastOp.getSource()},
        {broadcastOp.getSource().getType()});
    rewriter.setInsertionPointAfter(newWarpOp);
    Value broadcasted = rewriter.create<vector::BroadcastOp>(
        loc, destVecType, newWarpOp->getResults().back());
    newWarpOp->getResult(operandNumber).replaceAllUsesWith(broadcasted);
    return success();
  }
};

/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
/// the scf.ForOp is the last operation in the region so that it doesn't
/// change the order of execution. This creates a new scf.for region after the
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
/// WarpExecuteOnLane0Op region. Example:
/// ```
/// %w = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   %v1 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %v)
///       -> (vector<128xf32>) {
///     ...
///     scf.yield %r : vector<128xf32>
///   }
///   vector.yield %v1 : vector<128xf32>
/// }
/// ```
/// To:
/// ```
/// %w0 = vector.warp_execute_on_lane_0(%laneid) -> (vector<4xf32>) {
///   ...
///   vector.yield %v : vector<128xf32>
/// }
/// %w = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%varg = %w0)
///     -> (vector<4xf32>) {
///   %iw = vector.warp_execute_on_lane_0(%laneid)
///       args(%varg : vector<4xf32>) -> (vector<4xf32>) {
///     ^bb0(%arg: vector<128xf32>):
///       ...
///       vector.yield %ir : vector<128xf32>
///   }
///   scf.yield %iw : vector<4xf32>
/// }
/// ```
struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
  using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    auto yield = cast<vector::YieldOp>(
        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
    // Only pick up forOp if it is the last op in the region.
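    // (Sinking a loop that is followed by other ops could reorder it past
    // side-effecting operations.)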
    Operation *lastNode = yield->getPrevNode();
    auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode);
    if (!forOp)
      return failure();
    SmallVector<Value> newOperands;
    SmallVector<unsigned> resultIdx;
    // Collect all the outputs coming from the forOp.
    for (OpOperand &yieldOperand : yield->getOpOperands()) {
      if (yieldOperand.get().getDefiningOp() != forOp.getOperation())
        continue;
      auto forResult = yieldOperand.get().cast<OpResult>();
      newOperands.push_back(warpOp.getResult(yieldOperand.getOperandNumber()));
      yieldOperand.set(forOp.getIterOperands()[forResult.getResultNumber()]);
      resultIdx.push_back(yieldOperand.getOperandNumber());
    }
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPointAfter(warpOp);
    // Create a new for op outside the region with a WarpExecuteOnLane0Op
    // region inside.
    auto newForOp = rewriter.create<scf::ForOp>(
        forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
        forOp.getStep(), newOperands);
    rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
    auto innerWarp = rewriter.create<WarpExecuteOnLane0Op>(
        warpOp.getLoc(), newForOp.getResultTypes(), warpOp.getLaneid(),
        warpOp.getWarpSize(), newForOp.getRegionIterArgs(),
        forOp.getResultTypes());

    SmallVector<Value> argMapping;
    argMapping.push_back(newForOp.getInductionVar());
    for (Value args : innerWarp.getBody()->getArguments()) {
      argMapping.push_back(args);
    }
    SmallVector<Value> yieldOperands;
    for (Value operand : forOp.getBody()->getTerminator()->getOperands())
      yieldOperands.push_back(operand);
    rewriter.eraseOp(forOp.getBody()->getTerminator());
    rewriter.mergeBlocks(forOp.getBody(), innerWarp.getBody(), argMapping);
    rewriter.setInsertionPoint(innerWarp.getBody(), innerWarp.getBody()->end());
    rewriter.create<vector::YieldOp>(innerWarp.getLoc(), yieldOperands);
    rewriter.setInsertionPointAfter(innerWarp);
    if (!innerWarp.getResults().empty())
      rewriter.create<scf::YieldOp>(forOp.getLoc(), innerWarp.getResults());
    rewriter.eraseOp(forOp);
    // Replace the warpOp result coming from the original ForOp.
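    // Note: replaceAllUsesWith below also redirects the init operands of
    // newForOp (operands 3..n of scf.for, after lb/ub/step), so they are
    // reset to the original warp op results afterwards.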
    for (const auto &res : llvm::enumerate(resultIdx)) {
      warpOp.getResult(res.value())
          .replaceAllUsesWith(newForOp.getResult(res.index()));
      newForOp->setOperand(res.index() + 3, warpOp.getResult(res.value()));
    }
    return success();
  }
};

/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
/// The vector is reduced in parallel. Currently limited to vector size
/// matching the warpOp size. E.g.:
/// ```
/// %r = vector.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
///   %0 = "some_def"() : () -> (vector<32xf32>)
///   %1 = vector.reduction "add", %0 : vector<32xf32> into f32
///   vector.yield %1 : f32
/// }
/// ```
/// is lowered to:
/// ```
/// %0 = vector.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
///   %1 = "some_def"() : () -> (vector<32xf32>)
///   vector.yield %1 : vector<32xf32>
/// }
/// %a = vector.extract %0[0] : vector<1xf32>
/// %r = ("warp.reduction %a")
/// ```
struct WarpOpReduction : public OpRewritePattern<WarpExecuteOnLane0Op> {
  WarpOpReduction(MLIRContext *context,
                  DistributedReductionFn distributedReductionFn,
                  PatternBenefit benefit = 1)
      : OpRewritePattern<WarpExecuteOnLane0Op>(context, benefit),
        distributedReductionFn(distributedReductionFn) {}

  LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
                                PatternRewriter &rewriter) const override {
    OpOperand *yieldOperand = getWarpResult(
        warpOp, [](Operation *op) { return isa<vector::ReductionOp>(op); });
    if (!yieldOperand)
      return failure();

    auto reductionOp =
        cast<vector::ReductionOp>(yieldOperand->get().getDefiningOp());
    auto vectorType = reductionOp.getVector().getType().cast<VectorType>();
    // Only rank 1 vectors supported.
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          warpOp, "Only rank 1 reductions can be distributed.");
    // Only warp_size-sized vectors supported.
    if (static_cast<uint64_t>(vectorType.getShape()[0]) != warpOp.getWarpSize())
      return rewriter.notifyMatchFailure(
          warpOp, "Reduction vector dimension must match warp size.");
    // Only f32 and i32 element types are supported.
    if (!reductionOp.getType().isF32() &&
        !reductionOp.getType().isSignlessInteger(32))
      return rewriter.notifyMatchFailure(
          warpOp,
          "Reduction distribution currently only supports 32-bit types.");

    Location yieldLoc = yieldOperand->getOwner()->getLoc();

    // Yield the vector that will be reduced out of the WarpExecuteOnLane0Op.
    unsigned operandIndex = yieldOperand->getOperandNumber();
    SmallVector<Value> yieldValues = {reductionOp.getVector()};
    SmallVector<Type> retTypes = {VectorType::get({1}, reductionOp.getType())};
    unsigned numResults = warpOp.getNumResults();
    WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
        rewriter, warpOp, yieldValues, retTypes);
    rewriter.setInsertionPointAfter(newWarpOp);

    // Every lane has one scalar value. These should be reduced.
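    // The per-lane values are combined by the caller-provided
    // distributedReductionFn, e.g. via warp shuffles (implementation-defined
    // by the caller).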
    Value laneValVec = newWarpOp.getResult(numResults);
    Value laneVal = rewriter.create<vector::ExtractOp>(yieldLoc, laneValVec, 0);
    laneVal =
        distributedReductionFn(reductionOp.getLoc(), rewriter, laneVal,
                               reductionOp.getKind(), newWarpOp.getWarpSize());
    newWarpOp.getResult(operandIndex).replaceAllUsesWith(laneVal);
    return success();
  }

private:
  DistributedReductionFn distributedReductionFn;
};

} // namespace

void mlir::vector::populateWarpExecuteOnLane0OpToScfForPattern(
    RewritePatternSet &patterns,
    const WarpExecuteOnLane0LoweringOptions &options) {
  patterns.add<WarpOpToScfForPattern>(patterns.getContext(), options);
}

void mlir::vector::populateDistributeTransferWriteOpPatterns(
    RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn) {
  patterns.add<WarpOpTransferWrite>(patterns.getContext(), distributionMapFn);
}

void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<WarpOpElementwise, WarpOpTransferRead, WarpOpDeadResult,
               WarpOpBroadcast, WarpOpForwardOperand, WarpOpScfForOp>(
      patterns.getContext());
}

void mlir::vector::populateDistributeReduction(
    RewritePatternSet &patterns,
    DistributedReductionFn distributedReductionFn) {
  patterns.add<WarpOpReduction>(patterns.getContext(), distributedReductionFn);
}

void mlir::vector::moveScalarUniformCode(WarpExecuteOnLane0Op warpOp) {
  Block *body = warpOp.getBody();

  // Keep track of the ops we want to hoist.
  llvm::SmallSetVector<Operation *, 8> opsToMove;

  // Helper to check if a value is or will be defined outside of the region.
  auto isDefinedOutsideOfBody = [&](Value value) {
    auto *definingOp = value.getDefiningOp();
    return (definingOp && opsToMove.count(definingOp)) ||
           warpOp.isDefinedOutsideOfRegion(value);
  };

  // Do not use walk here, as we do not want to go into nested regions and
  // hoist operations from there.
  for (auto &op : body->without_terminator()) {
    bool hasVectorResult = llvm::any_of(op.getResults(), [](Value result) {
      return result.getType().isa<VectorType>();
    });
    if (!hasVectorResult && canBeHoisted(&op, isDefinedOutsideOfBody))
      opsToMove.insert(&op);
  }

  // Move all the ops marked as uniform outside of the region.
  for (Operation *op : opsToMove)
    op->moveBefore(warpOp);
}