//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableWrite {
  // The transfer_write to hoist; a default-constructed (null) op marks an
  // invalid/absent HoistableWrite.
  vector::TransferWriteOp transferWriteOp;
  // Optional subtensor_insert wrapping the write's result; null when the
  // transfer_write result is yielded directly.
  SubTensorInsertOp subTensorInsertOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableRead {
  // The transfer_read to hoist; null marks an invalid/absent HoistableRead.
  vector::TransferReadOp transferReadOp;
  // Optional subtensor feeding the read; set iff the matching write goes
  // through a subtensor_insert.
  SubTensorOp subTensorOp;
};
} // namespace

/// Return true if op1 and op2 are the same constant or the same SSA value.
54 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) { 55 auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> { 56 Attribute attr = ofr.dyn_cast<Attribute>(); 57 // Note: isa+cast-like pattern allows writing the condition below as 1 line. 58 if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>()) 59 attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue(); 60 if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>()) 61 return intAttr.getValue().getSExtValue(); 62 return llvm::None; 63 }; 64 auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2); 65 if (cst1 && cst2 && *cst1 == *cst2) 66 return true; 67 auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>(); 68 return v1 && v2 && v1 == v2; 69 } 70 71 /// Return true is all offsets, sizes and strides are equal. 72 static bool sameOffsetsSizesAndStrides(SubTensorOp s, SubTensorInsertOp si) { 73 if (s.static_offsets().size() != si.static_offsets().size()) 74 return false; 75 if (s.static_sizes().size() != si.static_sizes().size()) 76 return false; 77 if (s.static_strides().size() != si.static_strides().size()) 78 return false; 79 for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets())) 80 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 81 return false; 82 for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes())) 83 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 84 return false; 85 for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides())) 86 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 87 return false; 88 return true; 89 } 90 91 /// Look for a HoistableRead, in the given tensor uses, accessing the same 92 /// offset as the HoistableWrite. 
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.subTensorInsertOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead subTensorInsertOp: "
                      << *write.subTensorInsertOp.getOperation() << "\n");

  // Scan every user of the source tensor for a transfer_read (possibly
  // through a subtensor) that mirrors the write.
  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If HoistableWrite involves a SubTensorInsertOp, we need to find a
    // matching SubTensorOp.
    SubTensorOp subTensorOp;
    Operation *maybeTransferReadUser = user;
    if (write.subTensorInsertOp) {
      subTensorOp = dyn_cast<SubTensorOp>(user);
      // The subtensor must produce the same type the subtensor_insert
      // consumes, otherwise they cannot form a matched pair.
      if (!subTensorOp || subTensorOp.getResult().getType() !=
                              write.subTensorInsertOp.source().getType())
        continue;

      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *subTensorOp << " vs " << *write.subTensorInsertOp
                        << "\n");
      if (!sameOffsetsSizesAndStrides(subTensorOp, write.subTensorInsertOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // If we got here, subTensorOp is hoistable iff it has exactly 2 uses:
      //   1. the transfer_write we want to hoist.
      //   2. a matching transfer_read.
      // Anything else, we skip.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : subTensorOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        if (otherUser) {
          // Second non-write user found: more than 2 uses overall, bail.
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    // The read matches iff it loads from the same indices with the same
    // vector type as the write stores.
    if (read && read.indices() == write.transferWriteOp.indices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, subTensorOp};
  }
  // No matching read found: return a null HoistableRead.
  return HoistableRead();
}

/// Check if the chunk of data inserted by the HoistableWrite is read by any
/// other op than the HoistableRead candidate.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write. Worklist of use-ranges, traversed transitively.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.subTensorOp || user == write.transferWriteOp ||
          user == write.subTensorInsertOp)
        continue;
      // Consider all transitive uses through a subtensor / subtensor_insert.
      // TODO: atm we just bail because a stronger analysis is needed for these
      // cases.
      if (isa<SubTensorOp, SubTensorInsertOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        // Map the init-arg operand to the corresponding region iter_arg
        // (+1 skips the induction variable block argument).
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // A transfer_read is benign only if it provably accesses a slice
      // disjoint from the hoisted write; any other op is an unknown access.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}

/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional subtensor_insert or if any of the indexings
/// is `forOp`-dependent.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  // Case 1: the yielded value is directly a transfer_write.
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  // Case 2: the yielded value is a subtensor_insert of a transfer_write.
  if (auto subTensorInsertOp = v.getDefiningOp<SubTensorInsertOp>()) {
    // Inserted subTensor must come from vector.transfer_write.
    auto write =
        subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
    auto bbArg = subTensorInsertOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    for (Value operand : subTensorInsertOp->getOperands().drop_front(
             SubTensorInsertOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, subTensorInsertOp};
  }

  // Neither form matched: not hoistable.
  return HoistableWrite();
}

/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.subTensorOp && write.subTensorInsertOp) ||
          (!read.subTensorOp && !write.subTensorInsertOp)) &&
         "expected matching subtensor / subtensor_insert");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read subtensor is present, hoist it.
  if (read.subTensorOp && failed(forOp.moveOutOfLoop({read.subTensorOp})))
    llvm_unreachable("Unexpected failure moving subtensor out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor: the hoisted read must now source from the
  // loop's init operand instead of the (now out-of-scope) iter_arg.
  if (read.subTensorOp)
    read.subTensorOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.subTensorInsertOp)
    write.subTensorInsertOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield: forward the pre-write tensor instead of the written one.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.subTensorInsertOp)
    yieldOp->setOperand(initArgNumber, write.subTensorInsertOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields carrying the vector value.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a subtensor_insert is present or not, it carries the
  // update on the tensor operands.
  if (write.subTensorInsertOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.subTensorInsertOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.subTensorOp.result());
    write.subTensorInsertOp.destMutable().assign(read.subTensorOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}

// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
// with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
// tensor_read or transfer_write. For transfer_write uses recurse to make sure
// the new tensor has the same restrictions on its uses.
// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can be
// removed by the canonicalization pass.
336 void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) { 337 bool changed = true; 338 while (changed) { 339 changed = false; 340 func.walk([&](scf::ForOp forOp) { 341 Operation *yield = forOp.getBody()->getTerminator(); 342 for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) { 343 OpOperand &ret = yield->getOpOperand(it.index()); 344 HoistableWrite write = 345 getLoopInvariantTransferWriteOpDefining(forOp, ret); 346 if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse()) 347 continue; 348 LLVM_DEBUG(dbgs() << "\n"; 349 DBGS() << "Candidate write for hoisting: " 350 << *write.transferWriteOp.getOperation() << "\n"); 351 if (write.subTensorInsertOp) 352 LLVM_DEBUG(DBGS() << "Candidate subtensor_insert for hoisting: " 353 << *write.subTensorInsertOp.getOperation() << "\n"); 354 if (llvm::any_of(write.transferWriteOp.indices(), 355 [&forOp](Value index) { 356 return !forOp.isDefinedOutsideOfLoop(index); 357 })) 358 continue; 359 // Find a read with the same type and indices. 360 HoistableRead matchingRead = 361 findMatchingTransferRead(write, it.value()); 362 // Make sure none of the other uses read the part of the tensor modified 363 // by the transfer_write. 364 if (!matchingRead.transferReadOp || 365 tensorChunkAccessedByUnknownOp(write, matchingRead, it.value())) 366 continue; 367 368 LLVM_DEBUG(DBGS() << "Start hoisting\n"); 369 hoistReadWrite(matchingRead, write, it.value()); 370 changed = true; 371 forOp.erase(); 372 373 // Need to interrupt and restart: erasing the loop messes up the walk. 374 return WalkResult::interrupt(); 375 } 376 return WalkResult::advance(); 377 }); 378 // Apply canonicalization so the newForOp + yield folds immediately, thus 379 // cleaning up the IR and potentially enabling more hoisting. 
380 if (changed) { 381 RewritePatternSet patterns(func->getContext()); 382 scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext()); 383 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 384 } 385 } 386 } 387 388 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 389 bool changed = true; 390 while (changed) { 391 changed = false; 392 393 func.walk([&](vector::TransferReadOp transferRead) { 394 if (!transferRead.getShapedType().isa<MemRefType>()) 395 return WalkResult::advance(); 396 397 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 398 << *transferRead.getOperation() << "\n"); 399 auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp()); 400 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() 401 << "\n"); 402 if (!loop) 403 return WalkResult::advance(); 404 405 if (failed(moveLoopInvariantCode( 406 cast<LoopLikeOpInterface>(loop.getOperation())))) 407 llvm_unreachable( 408 "Unexpected failure to move invariant code out of loop"); 409 410 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 411 << "\n"); 412 413 llvm::SetVector<Operation *> forwardSlice; 414 getForwardSlice(transferRead.getOperation(), &forwardSlice); 415 416 // Look for the last TransferWriteOp in the forwardSlice of 417 // `transferRead` that operates on the same memref. 418 vector::TransferWriteOp transferWrite; 419 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 420 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 421 if (!candidateWrite || candidateWrite.source() != transferRead.source()) 422 continue; 423 transferWrite = candidateWrite; 424 } 425 426 // All operands of the TransferRead must be defined outside of the loop. 427 for (auto operand : transferRead.getOperands()) 428 if (!loop.isDefinedOutsideOfLoop(operand)) 429 return WalkResult::advance(); 430 431 // Only hoist transfer_read / transfer_write pairs for now. 
432 if (!transferWrite) 433 return WalkResult::advance(); 434 435 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 436 << "\n"); 437 438 // Approximate aliasing by checking that: 439 // 1. indices are the same, 440 // 2. no other operations in the loop access the same memref except 441 // for transfer_read/transfer_write accessing statically disjoint 442 // slices. 443 if (transferRead.indices() != transferWrite.indices() && 444 transferRead.getVectorType() == transferWrite.getVectorType()) 445 return WalkResult::advance(); 446 447 // TODO: may want to memoize this information for performance but it 448 // likely gets invalidated often. 449 DominanceInfo dom(loop); 450 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 451 return WalkResult::advance(); 452 for (auto &use : transferRead.source().getUses()) { 453 if (!dom.properlyDominates(loop, use.getOwner())) 454 continue; 455 if (use.getOwner() == transferRead.getOperation() || 456 use.getOwner() == transferWrite.getOperation()) 457 continue; 458 if (auto transferWriteUse = 459 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 460 if (!isDisjointTransferSet( 461 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 462 cast<VectorTransferOpInterface>( 463 transferWriteUse.getOperation()))) 464 return WalkResult::advance(); 465 } else if (auto transferReadUse = 466 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 467 if (!isDisjointTransferSet( 468 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 469 cast<VectorTransferOpInterface>( 470 transferReadUse.getOperation()))) 471 return WalkResult::advance(); 472 } else { 473 // Unknown use, we cannot prove that it doesn't alias with the 474 // transferRead/transferWrite operations. 475 return WalkResult::advance(); 476 } 477 } 478 479 // Hoist read before. 
480 if (failed(loop.moveOutOfLoop({transferRead}))) 481 llvm_unreachable( 482 "Unexpected failure to move transfer read out of loop"); 483 484 // Hoist write after. 485 transferWrite->moveAfter(loop); 486 487 // Rewrite `loop` with new yields by cloning and erase the original loop. 488 OpBuilder b(transferRead); 489 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 490 transferWrite.vector()); 491 492 // Transfer write has been hoisted, need to update the written value to 493 // the value yielded by the newForOp. 494 transferWrite.vector().replaceAllUsesWith( 495 newForOp.getResults().take_back()[0]); 496 497 changed = true; 498 loop.erase(); 499 // Need to interrupt and restart because erasing the loop messes up the 500 // walk. 501 return WalkResult::interrupt(); 502 }); 503 } 504 } 505 506 /// Ensure prerequisites that guarantee pad op hoisting can occur. 507 /// Return failure in the cases when we cannot perform hoisting; i.e. if either: 508 /// 1. There exists a use of `padTensorOp` that is not a linalg input operand. 509 /// 2. There isn't an enclosing `outermostEnclosingForOp` loop. 510 /// 3. There exists an op with a region that is dominated by 511 /// `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a 512 /// LinalgOp. 513 /// 3. There exists an op with side effects that is dominated by 514 /// `outermostEnclosingForOp` and that isn't a LoopLikeInterface. 515 /// 516 /// While ensuring prerequisites: 517 /// 1. Fill the `backwardSlice` to contain the topologically sorted ops 518 /// dominated by `outermostEnclosingForOp`. 519 /// 2. Fill the `packingLoops` to contain only the enclosing loops of 520 /// `backwardSlice` whose IV is actually used in computing padding. Loops that 521 /// remain in `backwardSlice` but that are not in `packingLoops` are 522 /// dimensions of reuse. 
static LogicalResult
hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
                                   llvm::SetVector<Operation *> &backwardSlice,
                                   llvm::SetVector<Operation *> &packingLoops) {
  // Bail on any use that isn't an input of a Linalg op.
  // Hoisting of inplace updates happens after vectorization.
  for (OpOperand &use : padTensorOp.result().getUses()) {
    auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
    if (!linalgUser || !linalgUser.isInputTensor(&use))
      return failure();
  }

  // Get at most nLevels of enclosing loops, innermost first.
  SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
  Operation *outermostEnclosingForOp = nullptr,
            *nextEnclosingForOp =
                padTensorOp->getParentOfType<LoopLikeOpInterface>();
  while (nLevels-- > 0 && nextEnclosingForOp) {
    outermostEnclosingForOp = nextEnclosingForOp;
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingForOp =
        nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
  }
  if (!outermostEnclosingForOp)
    return failure();

  // Get the backwards slice from `padTensorOp` that is dominated by the
  // outermost enclosing loop.
  DominanceInfo domInfo(outermostEnclosingForOp);
  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
                   [&](Operation *op) {
                     return domInfo.dominates(outermostEnclosingForOp, op);
                   });

  // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
  if (llvm::any_of(backwardSlice, [](Operation *op) {
        return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
               !isa<LinalgOp>(op);
      }))
    return failure();

  // Filter out the loops whose induction variable is not used to compute the
  // padded result. As a first approximation, just look for IVs that have no use
  // in the backwardSlice.
  // These are the dimensions of reuse that we can exploit to reduce the amount
  // of work / memory.
  // TODO: would this optimization compose better as a canonicalization?
  for (LoopLikeOpInterface loop : llvm::reverse(reverseEnclosingLoops)) {
    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
    if (!forOp)
      continue;
    // A loop "packs" iff at least one user of its IV is in the slice.
    for (Operation *user : forOp.getInductionVar().getUsers()) {
      if (backwardSlice.contains(user)) {
        packingLoops.insert(forOp);
        break;
      }
    }
  }

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  assert(outermostEnclosingForOp == backwardSlice.front());

  return success();
}

/// Return the number of iterations in the loop (ub - lb).ceilDiv(step).
static Value buildLoopTripCount(OpBuilder &b, scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  // (ub - lb) ceildiv step, with lb/ub as dims and step as a symbol.
  AffineExpr lb, ub, step;
  bindDims(ctx, lb, ub);
  bindSymbols(ctx, step);
  return b.create<AffineApplyOp>(
      forOp->getLoc(), AffineMap::get(2, 1, {(ub - lb).ceilDiv(step)}, ctx),
      ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()});
}

/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  // (iv - lb) ceildiv step, with iv/lb as dims and step as a symbol.
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  return b.create<AffineApplyOp>(
      forOp->getLoc(), AffineMap::get(2, 1, {(iv - lb).ceilDiv(step)}, ctx),
      ValueRange{forOp.getInductionVar(), forOp.lowerBound(), forOp.step()});
}

LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                  unsigned nLoops) {
  llvm::SetVector<Operation *> backwardSlice, packingLoops;
  if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
                                                backwardSlice, packingLoops)))
    return failure();

  // Update actual number of loops, which may be smaller.
  nLoops = packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  unsigned paddedRank = paddedTensorType.getRank();

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  Operation *outermostEnclosingForOp = backwardSlice.front();
  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding.
  SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  // One dynamic dimension per packing loop, sized by its trip count.
  auto dynamicSizes =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
        return buildLoopTripCount(b, cast<scf::ForOp>(op));
      }));
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create a SubTensorInsert at the top of the stack.
  //   3. Iteratively pop and yield the result of the SubTensorInsertOp across
  //      the cloned loops.
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nLoops);
  leadingPackedTensorIndexings.reserve(nLoops);
  BlockAndValueMapping bvm;
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  // Insert `padTensorOp` into the backwardSlice so we clone it too.
  backwardSlice.insert(padTensorOp);
  for (Operation *op : backwardSlice) {
    // Region-less ops (and the pad itself) are cloned verbatim through bvm.
    if (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op)) {
      b.clone(*op, bvm);
      continue;
    }
    // TODO: support more cases as they appear.
    auto forOp = dyn_cast<scf::ForOp>(op);
    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
    // Unused loop, just skip it.
    if (!packingLoops.contains(forOp))
      continue;
    // Clone the loop, threading `packedTensor` through as an iter_arg.
    auto clonedForOp =
        b.create<scf::ForOp>(loc, forOp.lowerBound(), forOp.upperBound(),
                             forOp.step(), packedTensor);
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());
    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    leadingPackedTensorIndexings.push_back(
        buildLoopIterationCount(b, clonedForOp));
    bvm.map(forOp.getInductionVar(), clonedLoopIvs.back());
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Stack step 2. create SubTensorInsertOp at the top of the stack.
  // offsets = [clonedLoopIvs, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape].
  SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
  for (int64_t sz : paddedTensorType.getShape()) {
    // TODO: go grab dims when necessary, for now PadTensorOp returns a static
    // tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));

  Value inserted =
      b.create<SubTensorInsertOp>(loc, bvm.lookup(padTensorOp.result()),
                                  packedTensor, offsets, sizes, strides);

  // Stack step 3. iteratively pop the stack and propagate the yield.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(padTensorOp);
  SmallVector<Value> loopIterationCounts =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(b, cast<scf::ForOp>(loop));
      }));
  // offsets = [originalLoopIvs, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  padTensorOp.replaceAllUsesWith(
      b.create<SubTensorOp>(loc, padTensorOp.getResultType(), packedTensor,
                            offsets, sizes, strides)
          ->getResult(0));

  // Keep a handle to the original op before rebinding `padTensorOp` below.
  Operation *toErase = padTensorOp;

  // Make the newly cloned `padTensorOp` available to the caller.
  padTensorOp =
      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());

  toErase->erase();

  return success();
}