//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

// Debug stream for this file, prefixed with "[linalg-hoisting] ".
#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableWrite {
  // The vector.transfer_write candidate for hoisting out of an scf.for.
  vector::TransferWriteOp transferWriteOp;
  // Optional tensor.insert_slice that re-inserts the written slice into the
  // yielded tensor; null when the transfer_write targets the tensor directly.
  tensor::InsertSliceOp insertSliceOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
50 struct HoistableRead { 51 vector::TransferReadOp transferReadOp; 52 tensor::ExtractSliceOp extractSliceOp; 53 }; 54 } // namespace 55 56 /// Return true if op1 and op2 are the same constant or the same SSA value. 57 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) { 58 auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> { 59 Attribute attr = ofr.dyn_cast<Attribute>(); 60 // Note: isa+cast-like pattern allows writing the condition below as 1 line. 61 if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>()) 62 attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue(); 63 if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>()) 64 return intAttr.getValue().getSExtValue(); 65 return llvm::None; 66 }; 67 auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2); 68 if (cst1 && cst2 && *cst1 == *cst2) 69 return true; 70 auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>(); 71 return v1 && v2 && v1 == v2; 72 } 73 74 /// Return true is all offsets, sizes and strides are equal. 
75 static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s, 76 tensor::InsertSliceOp si) { 77 if (s.static_offsets().size() != si.static_offsets().size()) 78 return false; 79 if (s.static_sizes().size() != si.static_sizes().size()) 80 return false; 81 if (s.static_strides().size() != si.static_strides().size()) 82 return false; 83 for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets())) 84 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 85 return false; 86 for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes())) 87 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 88 return false; 89 for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides())) 90 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 91 return false; 92 return true; 93 } 94 95 /// Look for a HoistableRead, in the given tensor uses, accessing the same 96 /// offset as the HoistableWrite. 97 static HoistableRead findMatchingTransferRead(HoistableWrite write, 98 Value srcTensor) { 99 assert(write.transferWriteOp && 100 "expected hoistable write to have a .transfer_write"); 101 102 LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: " 103 << *write.transferWriteOp.getOperation() << "\n"); 104 if (write.insertSliceOp) 105 LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: " 106 << *write.insertSliceOp.getOperation() << "\n"); 107 108 for (Operation *user : srcTensor.getUsers()) { 109 LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user 110 << "\n"); 111 112 // If HoistableWrite involves a InsertSliceOp, we need to find a 113 // matching ExtractSliceOp. 
114 tensor::ExtractSliceOp sliceOp; 115 Operation *maybeTransferReadUser = user; 116 if (write.insertSliceOp) { 117 sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); 118 if (!sliceOp || sliceOp.getResult().getType() != 119 write.insertSliceOp.source().getType()) 120 continue; 121 122 LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: " 123 << *sliceOp << " vs " << *write.insertSliceOp << "\n"); 124 if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp)) 125 continue; 126 127 LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n"); 128 // If we got here, sliceOp is hoistable iff it has exactly 2 uses: 129 // 1. the transfer_write we want to hoist. 130 // 2. a matching transfer_read. 131 // Anything else, we skip. 132 bool skip = false; 133 Operation *otherUser = nullptr; 134 for (Operation *u : sliceOp->getUsers()) { 135 if (u == write.transferWriteOp) 136 continue; 137 if (otherUser) { 138 skip = true; 139 break; 140 } 141 otherUser = u; 142 } 143 if (skip || !otherUser) 144 continue; 145 maybeTransferReadUser = otherUser; 146 } 147 148 LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser 149 << "\n"); 150 auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser); 151 if (read && read.indices() == write.transferWriteOp.indices() && 152 read.getVectorType() == write.transferWriteOp.getVectorType()) 153 return HoistableRead{read, sliceOp}; 154 } 155 return HoistableRead(); 156 } 157 158 /// Check if the chunk of data inserted by the HoistableWrite are read by any 159 /// other op than the HoistableRead candidate. 160 static bool tensorChunkAccessedByUnknownOp(HoistableWrite write, 161 HoistableRead candidateRead, 162 BlockArgument tensorArg) { 163 // Make sure none of the other uses read the part of the tensor modified 164 // by the transfer_write. 
  // Worklist of use-ranges to inspect, seeded with the loop tensor argument.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Consider all transitive uses through a extract_slice / insert_slice.
      // TODO: atm we just bail because a stronger analysis is needed for these
      // cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        // Map the operand to the corresponding region iter_arg (offset by the
        // control operands and the induction variable).
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // Any remaining use must be a transfer_read over a statically disjoint
      // slice; anything else is an unknown access and we bail.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}

/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional insert_slice or if any of the indexings
/// is `forOp`-dependent.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  // Case 1: the yielded value is produced directly by a transfer_write.
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  // Case 2: the yielded value is produced by an insert_slice fed by a
  // transfer_write.
  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
    // Inserted slice must come from vector.transfer_write.
    auto write =
        insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
    auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    // Only the offset/size/stride operands are checked (drop the source and
    // dest operands that precede them).
    for (Value operand : insertSliceOp->getOperands().drop_front(
             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, insertSliceOp};
  }

  return HoistableWrite();
}

/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
/// Moves the read (and optional extract_slice) before `forOp`, the write (and
/// optional insert_slice) after it, and re-threads the tensor/vector SSA
/// values through a loop cloned with additional yields.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp})))
    llvm_unreachable("Unexpected failure moving extract_slice out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor: now that the read is outside the loop it must
  // read from the loop's init operand instead of the region iter_arg.
  if (read.extractSliceOp)
    read.extractSliceOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield: the loop now yields the tensor the write used to
  // produce inside the loop.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a insert_slice is present or not, it carries the
  // update on the tensor operands.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
    write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}

// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
//    with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
//    tensor_read or transfer_write. For transfer_write uses recurse to make
//    sure the new tensor has the same restrictions on its uses.
// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can be
// remove by the canonicalization pass.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  bool changed = true;
  // Iterate to a fixed point: each successful hoist may enable more.
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        // All write indices must be loop-invariant.
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor modified
        // by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
377 return WalkResult::interrupt(); 378 } 379 return WalkResult::advance(); 380 }); 381 // Apply canonicalization so the newForOp + yield folds immediately, thus 382 // cleaning up the IR and potentially enabling more hoisting. 383 if (changed) { 384 RewritePatternSet patterns(func->getContext()); 385 scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext()); 386 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 387 } 388 } 389 } 390 391 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 392 bool changed = true; 393 while (changed) { 394 changed = false; 395 396 func.walk([&](vector::TransferReadOp transferRead) { 397 if (!transferRead.getShapedType().isa<MemRefType>()) 398 return WalkResult::advance(); 399 400 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 401 << *transferRead.getOperation() << "\n"); 402 auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp()); 403 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() 404 << "\n"); 405 if (!loop) 406 return WalkResult::advance(); 407 408 if (failed(moveLoopInvariantCode( 409 cast<LoopLikeOpInterface>(loop.getOperation())))) 410 llvm_unreachable( 411 "Unexpected failure to move invariant code out of loop"); 412 413 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 414 << "\n"); 415 416 SetVector<Operation *> forwardSlice; 417 getForwardSlice(transferRead.getOperation(), &forwardSlice); 418 419 // Look for the last TransferWriteOp in the forwardSlice of 420 // `transferRead` that operates on the same memref. 421 vector::TransferWriteOp transferWrite; 422 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 423 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 424 if (!candidateWrite || candidateWrite.source() != transferRead.source()) 425 continue; 426 transferWrite = candidateWrite; 427 } 428 429 // All operands of the TransferRead must be defined outside of the loop. 
430 for (auto operand : transferRead.getOperands()) 431 if (!loop.isDefinedOutsideOfLoop(operand)) 432 return WalkResult::advance(); 433 434 // Only hoist transfer_read / transfer_write pairs for now. 435 if (!transferWrite) 436 return WalkResult::advance(); 437 438 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 439 << "\n"); 440 441 // Approximate aliasing by checking that: 442 // 1. indices are the same, 443 // 2. no other operations in the loop access the same memref except 444 // for transfer_read/transfer_write accessing statically disjoint 445 // slices. 446 if (transferRead.indices() != transferWrite.indices() && 447 transferRead.getVectorType() == transferWrite.getVectorType()) 448 return WalkResult::advance(); 449 450 // TODO: may want to memoize this information for performance but it 451 // likely gets invalidated often. 452 DominanceInfo dom(loop); 453 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 454 return WalkResult::advance(); 455 for (auto &use : transferRead.source().getUses()) { 456 if (!dom.properlyDominates(loop, use.getOwner())) 457 continue; 458 if (use.getOwner() == transferRead.getOperation() || 459 use.getOwner() == transferWrite.getOperation()) 460 continue; 461 if (auto transferWriteUse = 462 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 463 if (!isDisjointTransferSet( 464 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 465 cast<VectorTransferOpInterface>( 466 transferWriteUse.getOperation()))) 467 return WalkResult::advance(); 468 } else if (auto transferReadUse = 469 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 470 if (!isDisjointTransferSet( 471 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 472 cast<VectorTransferOpInterface>( 473 transferReadUse.getOperation()))) 474 return WalkResult::advance(); 475 } else { 476 // Unknown use, we cannot prove that it doesn't alias with the 477 // transferRead/transferWrite operations. 
478 return WalkResult::advance(); 479 } 480 } 481 482 // Hoist read before. 483 if (failed(loop.moveOutOfLoop({transferRead}))) 484 llvm_unreachable( 485 "Unexpected failure to move transfer read out of loop"); 486 487 // Hoist write after. 488 transferWrite->moveAfter(loop); 489 490 // Rewrite `loop` with new yields by cloning and erase the original loop. 491 OpBuilder b(transferRead); 492 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 493 transferWrite.vector()); 494 495 // Transfer write has been hoisted, need to update the written value to 496 // the value yielded by the newForOp. 497 transferWrite.vector().replaceAllUsesWith( 498 newForOp.getResults().take_back()[0]); 499 500 changed = true; 501 loop.erase(); 502 // Need to interrupt and restart because erasing the loop messes up the 503 // walk. 504 return WalkResult::interrupt(); 505 }); 506 } 507 } 508 509 /// Return success if `v` is a value that is only transitively defined by ops of 510 /// type in `OpTypeList`. 511 template <typename... OpTypeList> 512 static bool backwardsSliceOnlyHasOpsOfType(scf::ForOp outerLimit, Value v) { 513 // Compute a backward slice up to, but not including, `outerLimit`. 514 SetVector<Operation *> backwardSlice; 515 getBackwardSlice(v, &backwardSlice, [&](Operation *op) { 516 return outerLimit->isProperAncestor(op); 517 }); 518 // Traverse the backward slice and ensure we can perform the computation to 519 // hoist. 520 for (Operation *op : backwardSlice) { 521 if (isa<OpTypeList...>(op)) 522 continue; 523 LLVM_DEBUG(DBGS() << "Abort: unadmissible op in slice " << *op << "\n"); 524 return false; 525 } 526 return true; 527 } 528 529 bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) { 530 return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>(); 531 } 532 533 /// Compute the tightest lower bound with quantities that are all defined 534 /// outside of `outer`. 535 /// Return null if such a bound cannot be computed. 
536 Value computeLoopIndependentLowerBound(OpBuilder &b, scf::ForOp outer, 537 Value v) { 538 if (isDefinedOutsideOrConstant(outer, v)) 539 return v; 540 return Value(); 541 } 542 543 /// Compute the tightest upper bound with quantities that are all defined 544 /// outside of `outer`. 545 /// Expects all ops in the backward slice of `v` up to `outer` to be either 546 /// scf.for, affine.min or affine.apply. 547 static Value computeLoopIndependentUpperBound(OpBuilder &b, scf::ForOp outer, 548 Value v) { 549 if (isDefinedOutsideOrConstant(outer, v)) 550 return v; 551 552 LLVM_DEBUG(DBGS() << "Begin loopIndependentUpperBound for: " << v << "\n"); 553 554 bool ok = 555 backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp, AffineApplyOp>( 556 outer, v); 557 assert(ok && "expected to only be defined by scf::ForOp and AffineMinOp"); 558 (void)ok; 559 560 // Compute a backward slice up to, but not including, `outer`. 561 SetVector<Operation *> backwardSlice; 562 getBackwardSlice(v, &backwardSlice, 563 [&](Operation *op) { return outer->isProperAncestor(op); }); 564 backwardSlice.insert(v.getDefiningOp()); 565 566 OpBuilder::InsertionGuard g(b); 567 b.setInsertionPoint(outer); 568 Value res = v; 569 BlockAndValueMapping bvm; 570 for (Operation *op : backwardSlice) { 571 if (isa<scf::ForOp>(op)) 572 continue; 573 if (isa<AffineApplyOp>(op)) { 574 b.clone(*op, bvm); 575 continue; 576 } 577 auto sliceMinOp = cast<AffineMinOp>(op); 578 GetMinMaxExprFn getSCFMinMax = [&](Value value, 579 SmallVectorImpl<Value> &dims, 580 SmallVectorImpl<Value> &symbols) { 581 return getSCFMinMaxExpr(value, dims, symbols, [&](Operation *op) { 582 return outer->isAncestor(op); 583 }); 584 }; 585 // Perform the substitution of the operands of AffineMinOp. 
586 auto mapAndOperands = substituteMin(sliceMinOp, getSCFMinMax); 587 SmallVector<Value> resultOperands = mapAndOperands.dims; 588 llvm::append_range(resultOperands, mapAndOperands.symbols); 589 AffineMap map = mapAndOperands.map; 590 canonicalizeMapAndOperands(&map, &resultOperands); 591 map = simplifyAffineMap(map); 592 res = b.create<AffineMinOp>( 593 outer->getLoc(), map, 594 llvm::to_vector<4>(llvm::map_range(resultOperands, [&](Value operand) { 595 return bvm.lookupOrDefault(operand); 596 }))); 597 bvm.map(sliceMinOp, res); 598 } 599 LLVM_DEBUG(DBGS() << "End loopIndependentUpperBound with: " << res << "\n"); 600 return res; 601 } 602 603 /// Return the number of iterations in the loop (ub - lb).ceilDiv(step). 604 /// The returned Value is guaranteed not to depend on any loop comprised in 605 /// [`outer`, `forOp`]. 606 /// Return null if such a loop-independent quantity cannot be computed. 607 static Value buildLoopTripCount(OpBuilder &b, scf::ForOp outer, 608 scf::ForOp forOp) { 609 MLIRContext *ctx = forOp->getContext(); 610 AffineExpr lb, ub, step; 611 bindDims(ctx, lb, ub); 612 bindSymbols(ctx, step); 613 Value lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()), 614 ubVal = computeLoopIndependentUpperBound(b, outer, forOp.upperBound()), 615 stepVal = forOp.step(); 616 if (!lbVal || !ubVal || !stepVal) 617 return Value(); 618 auto loc = forOp->getLoc(); 619 Value res = b.create<AffineApplyOp>(loc, (ub - lb).ceilDiv(step), 620 ValueRange{lbVal, ubVal, stepVal}); 621 return res; 622 } 623 624 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step). 625 /// The returned Value is guaranteed not to depend on any loop comprised in 626 /// [`outer`, `forOp`]. 627 /// Return null if such a loop-independent quantity cannot be computed. 
static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  // Iteration count expression (iv - lb) ceildiv step, with iv/lb as dims and
  // step as a symbol.
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  Value ivVal = forOp.getInductionVar(),
        lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
        stepVal = forOp.step();
  if (!ivVal || !lbVal || !stepVal)
    return Value();
  auto loc = forOp->getLoc();
  return b.create<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
                                 ValueRange{ivVal, lbVal, stepVal});
}

/// Ensure prerequisites that guarantee pad op hoisting can occur.
/// Return failure in the cases when we cannot perform hoisting; i.e. if either:
///   1. There exists a use of `padTensorOp` that is not a linalg input operand.
///   2. There isn't an enclosing `outermostEnclosingForOp` loop.
///   3. There exists an op with a region that is dominated by
///      `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
///      LinalgOp.
///   4. There exists an op with side effects that is dominated by
///      `outermostEnclosingForOp` and that isn't a LoopLikeInterface.
///   5. The lower bound, upper bound and step of all the loops involved in the
///      hoisting can be computed above `outermostEnclosingForOp`: each must be
///      defined outside the outermost loop or be a constant (the upper bound
///      may alternatively derive only from scf.for / affine.min / affine.apply
///      ops).
///
/// While ensuring prerequisites:
///   1. Fill the `backwardSlice` to contain the topologically sorted ops
///      dominated by `outermostEnclosingForOp`.
///   2. Fill the `packingLoops` to contain only the enclosing loops of
///      `backwardSlice` whose IV is actually used in computing padding. Loops
///      that remain in `backwardSlice` but that are not in `packingLoops` are
///      dimensions of reuse.
static LogicalResult
hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
                                   SetVector<Operation *> &backwardSlice,
                                   SetVector<Operation *> &packingLoops,
                                   SmallVector<Value> &dynamicTensorSizes) {
  // Bail on any use that isn't an input of a Linalg op.
  // Hoisting of inplace updates happens after vectorization.
  for (OpOperand &use : padTensorOp.result().getUses()) {
    auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
    if (!linalgUser || !linalgUser.isInputTensor(&use))
      return failure();
  }

  // Get at most nLevels of enclosing loops.
  SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
  Operation *outermostEnclosingForOp = nullptr,
            *nextEnclosingForOp =
                padTensorOp->getParentOfType<LoopLikeOpInterface>();
  while (nLevels-- > 0 && nextEnclosingForOp) {
    outermostEnclosingForOp = nextEnclosingForOp;
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingForOp =
        nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
  }
  // No enclosing loop at all: nothing to hoist above.
  if (!outermostEnclosingForOp)
    return failure();

  // Get the backwards slice from `padTensorOp` that is dominated by the
  // outermost enclosing loop.
  DominanceInfo domInfo(outermostEnclosingForOp);
  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
                   [&](Operation *op) {
                     return domInfo.dominates(outermostEnclosingForOp, op);
                   });

  // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
  if (llvm::any_of(backwardSlice, [](Operation *op) {
        return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
               !isa<LinalgOp>(op);
      }))
    return failure();

  // Filter out the loops whose induction variable is not used to compute the
  // padded result. As a first approximation, just look for IVs that have no use
  // in the backwardSlice.
  // These are the dimensions of reuse that we can exploit to reduce the amount
  // of work / memory.
  // TODO: would this optimization compose better as a canonicalization?
  for (LoopLikeOpInterface loop : llvm::reverse(reverseEnclosingLoops)) {
    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
    if (!forOp)
      continue;
    // Keep only loops whose IV feeds the backward slice.
    for (Operation *user : forOp.getInductionVar().getUsers()) {
      if (backwardSlice.contains(user)) {
        packingLoops.insert(forOp);
        break;
      }
    }
  }

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  assert(outermostEnclosingForOp == backwardSlice.front());

  // Check every packing loop's bounds/step are computable above `outer`.
  scf::ForOp outer = cast<scf::ForOp>(outermostEnclosingForOp);
  if (llvm::any_of(packingLoops, [&](Operation *op) {
        scf::ForOp forOp = cast<scf::ForOp>(op);
        Value lb = forOp.lowerBound(), ub = forOp.upperBound(),
              step = forOp.step();
        return !isDefinedOutsideOrConstant(outer, lb) ||
               !(isDefinedOutsideOrConstant(outer, ub) ||
                 backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp,
                                                AffineApplyOp>(outer, ub)) ||
               !isDefinedOutsideOrConstant(outer, step);
      }))
    return failure();

  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);
  dynamicTensorSizes =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
        return buildLoopTripCount(b, cast<scf::ForOp>(outermostEnclosingForOp),
                                  cast<scf::ForOp>(op));
      }));
  // Assert all loop trip counts can be computed.
  if (!llvm::all_of(dynamicTensorSizes, [](Value v) { return v; }))
    llvm_unreachable("loop independence prerequisite not met");
  return success();
}

LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                  unsigned nLoops) {
  SmallVector<Value> dynamicTensorSizes;
  SetVector<Operation *> backwardSlice, packingLoops;
  if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
                                                backwardSlice, packingLoops,
                                                dynamicTensorSizes)))
    return failure();

  // Update actual number of loops, which may be smaller.
  nLoops = packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  unsigned paddedRank = paddedTensorType.getRank();

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  Operation *outermostEnclosingForOp = backwardSlice.front();
  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding.
  SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicTensorSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2.
Create a InsertSliceOp at the top of the stack. 793 // 3. Iteratively pop and yield the result of the InsertSliceOp across 794 // the cloned loops. 795 SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings; 796 clonedLoopIvs.reserve(nLoops); 797 leadingPackedTensorIndexings.reserve(nLoops); 798 BlockAndValueMapping bvm; 799 // Insert `padTensorOp` into the backwardSlice so we clone it too. 800 backwardSlice.insert(padTensorOp); 801 // Stack step 1. iteratively clone loops and push `packedTensor`. 802 for (Operation *op : backwardSlice) { 803 // Specifically sit out in the extract_slice(packedTensor) case: this is the 804 // piece we seek to replace. 805 if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) 806 if (bvm.lookupOrDefault(sliceOp.source()) == packedTensor) 807 continue; 808 auto effects = dyn_cast<MemoryEffectOpInterface>(op); 809 bool hasNoEffects = !effects || effects.hasNoEffect(); 810 if (hasNoEffects && 811 (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op))) { 812 b.clone(*op, bvm); 813 continue; 814 } 815 // TODO: support more cases as they appear. 816 auto forOp = dyn_cast<scf::ForOp>(op); 817 assert(forOp && "Expected scf::ForOp when hoisting pad ops"); 818 // Unused loop, just skip it. 819 if (!packingLoops.contains(forOp)) 820 continue; 821 822 auto clonedForOp = 823 b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()), 824 bvm.lookupOrDefault(forOp.upperBound()), 825 bvm.lookupOrDefault(forOp.step()), packedTensor); 826 // Map the induction var, region args and results to the `clonedForOp`. 
827 bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar()); 828 bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs()); 829 bvm.map(forOp.getResults(), clonedForOp.getResults()); 830 assert(clonedForOp->getNumRegions() == 1); 831 clonedLoopIvs.push_back(clonedForOp.getInductionVar()); 832 833 b.setInsertionPointToStart(&clonedForOp->getRegion(0).front()); 834 Value loopIndependentIterationCount = buildLoopIterationCount( 835 b, cast<scf::ForOp>(outermostEnclosingForOp), clonedForOp); 836 // Assert the loop-independent iteration count can be computed. 837 if (!loopIndependentIterationCount) 838 llvm_unreachable("loop independence prerequisite not met"); 839 leadingPackedTensorIndexings.push_back(loopIndependentIterationCount); 840 packedTensor = clonedForOp.getRegionIterArgs().front(); 841 } 842 843 // Stack step 2. create InsertSliceOp at the top of the stack. 844 // offsets = [clonedLoopIvs, 0 .. 0]. 845 SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(), 846 leadingPackedTensorIndexings.end()); 847 offsets.append(paddedRank, b.getIndexAttr(0)); 848 // sizes = [1 .. 1, paddedShape]. 849 SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1)); 850 for (int64_t sz : paddedTensorType.getShape()) { 851 // TODO: go grab dims when necessary, for now PadTensorOp returns a static 852 // tensor. 853 assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes"); 854 sizes.push_back(b.getIndexAttr(sz)); 855 } 856 // strides = [1 .. 1]. 857 SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1)); 858 859 Value inserted = 860 b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()), 861 packedTensor, offsets, sizes, strides); 862 863 // Stack step 3. iteratively pop the stack and propagate the yield. 
864 Value valueToYield = inserted; 865 for (Value iv : llvm::reverse(clonedLoopIvs)) { 866 auto forOp = scf::getForInductionVarOwner(iv); 867 b.setInsertionPointToEnd(&forOp.getRegion().front()); 868 b.create<scf::YieldOp>(loc, valueToYield); 869 valueToYield = forOp.getResult(0); 870 } 871 872 // Now the packed tensor is ready, replace the original padding op by a 873 // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1]. 874 b.setInsertionPoint(padTensorOp); 875 SmallVector<Value> loopIterationCounts = 876 llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) { 877 return buildLoopIterationCount( 878 b, cast<scf::ForOp>(outermostEnclosingForOp), 879 cast<scf::ForOp>(loop)); 880 })); 881 // Assert all loop iteration counts can be computed. 882 if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; })) 883 llvm_unreachable("loop independence prerequisite not met"); 884 // offsets = [originalLoopIvs, 0 .. 0]. 885 offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end()); 886 offsets.append(paddedRank, b.getIndexAttr(0)); 887 // sizes = [1 .. 1, paddedShape] (definedabove). 888 // strides = [1 .. 1] (defined above) 889 packedTensor = 890 scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0); 891 padTensorOp.replaceAllUsesWith( 892 b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(), 893 packedTensor, offsets, sizes, strides) 894 ->getResult(0)); 895 896 Operation *toErase = padTensorOp; 897 898 // Make the newly cloned `padTensorOp` available to the caller. 899 padTensorOp = 900 cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp()); 901 902 toErase->erase(); 903 904 return success(); 905 } 906