1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements functions concerned with hoisting invariant operations 10 // in the context of Linalg transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/SCF/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Dialect/Vector/VectorUtils.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/Dominance.h" 24 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 25 #include "mlir/Transforms/LoopUtils.h" 26 #include "llvm/ADT/StringRef.h" 27 #include "llvm/Support/Debug.h" 28 29 using llvm::dbgs; 30 31 #define DEBUG_TYPE "linalg-hoisting" 32 33 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 34 35 using namespace mlir; 36 using namespace mlir::linalg; 37 38 void mlir::linalg::hoistViewAllocOps(FuncOp func) { 39 bool changed = true; 40 while (changed) { 41 changed = false; 42 func.walk([&changed](Operation *op) { 43 if (!isa<AllocOp, AllocaOp, DeallocOp>(op)) 44 return; 45 46 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n"); 47 auto loop = dyn_cast<scf::ForOp>(op->getParentOp()); 48 LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n"); 49 50 // Only hoist out of immediately enclosing scf::ForOp. 51 if (!loop) 52 return; 53 54 // If any operand is defined inside the loop don't hoist. 55 if (llvm::any_of(op->getOperands(), [&](Value v) { 56 return !loop.isDefinedOutsideOfLoop(v); 57 })) 58 return; 59 60 LLVM_DEBUG(DBGS() << "All operands defined outside \n"); 61 62 // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist. 63 Value v; 64 if (op->getNumResults() > 0) { 65 assert(op->getNumResults() == 1 && "Unexpected multi-result alloc"); 66 v = op->getResult(0); 67 } 68 if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) { 69 return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner()); 70 })) { 71 LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n"); 72 return; 73 } 74 75 // Move AllocOp before the loop. 76 if (isa<AllocOp, AllocaOp>(op)) 77 (void)loop.moveOutOfLoop({op}); 78 else // Move DeallocOp outside of the loop. 79 op->moveAfter(loop); 80 changed = true; 81 }); 82 } 83 } 84 85 namespace { 86 /// Represents a unit of hoistable TransferWriteOp. This may comprise other 87 /// instructions that need to be hoisted too. 88 struct HoistableWrite { 89 vector::TransferWriteOp transferWriteOp; 90 SubTensorInsertOp subTensorInsertOp; 91 }; 92 /// Represents a unit of hoistable TransferReadOp. This may comprise other 93 /// instructions that need to be hoisted too. 94 struct HoistableRead { 95 vector::TransferReadOp transferReadOp; 96 SubTensorOp subTensorOp; 97 }; 98 } // namespace 99 100 /// Return true if op1 and op2 are the same constant or the same SSA value. 101 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) { 102 auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> { 103 Attribute attr = ofr.dyn_cast<Attribute>(); 104 // Note: isa+cast-like pattern allows writing the condition below as 1 line. 105 if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>()) 106 attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue(); 107 if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>()) 108 return intAttr.getValue().getSExtValue(); 109 return llvm::None; 110 }; 111 auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2); 112 if (cst1 && cst2 && *cst1 == *cst2) 113 return true; 114 auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>(); 115 return v1 && v2 && v1 == v2; 116 } 117 118 /// Return true is all offsets, sizes and strides are equal. 119 static bool sameOffsetsSizesAndStrides(SubTensorOp s, SubTensorInsertOp si) { 120 if (s.static_offsets().size() != si.static_offsets().size()) 121 return false; 122 if (s.static_sizes().size() != si.static_sizes().size()) 123 return false; 124 if (s.static_strides().size() != si.static_strides().size()) 125 return false; 126 for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets())) 127 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 128 return false; 129 for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes())) 130 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 131 return false; 132 for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides())) 133 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 134 return false; 135 return true; 136 } 137 138 /// Look for a HoistableRead, in the given tensor uses, accessing the same 139 /// offset as the HoistableWrite. 140 static HoistableRead findMatchingTransferRead(HoistableWrite write, 141 Value srcTensor) { 142 assert(write.transferWriteOp && 143 "expected hoistable write to have a .transfer_write"); 144 145 LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: " 146 << *write.transferWriteOp.getOperation() << "\n"); 147 if (write.subTensorInsertOp) 148 LLVM_DEBUG(DBGS() << "findMatchingTransferRead subTensorInsertOp: " 149 << *write.subTensorInsertOp.getOperation() << "\n"); 150 151 for (Operation *user : srcTensor.getUsers()) { 152 LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user 153 << "\n"); 154 155 // If HoistableWrite involves a SubTensorInsertOp, we need to find a 156 // matching SubTensorOp. 157 SubTensorOp subTensorOp; 158 Operation *maybeTransferReadUser = user; 159 if (write.subTensorInsertOp) { 160 subTensorOp = dyn_cast<SubTensorOp>(user); 161 if (!subTensorOp || subTensorOp.getResult().getType() != 162 write.subTensorInsertOp.source().getType()) 163 continue; 164 165 LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: " 166 << *subTensorOp << " vs " << *write.subTensorInsertOp 167 << "\n"); 168 if (!sameOffsetsSizesAndStrides(subTensorOp, write.subTensorInsertOp)) 169 continue; 170 171 LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n"); 172 // If we got here, subTensorOp is hoistable iff it has exactly 2 uses: 173 // 1. the transfer_write we want to hoist. 174 // 2. a matching transfer_read. 175 // Anything else, we skip. 176 bool skip = false; 177 Operation *otherUser = nullptr; 178 for (Operation *u : subTensorOp->getUsers()) { 179 if (u == write.transferWriteOp) 180 continue; 181 if (otherUser) { 182 skip = true; 183 break; 184 } 185 otherUser = u; 186 } 187 if (skip || !otherUser) 188 continue; 189 maybeTransferReadUser = otherUser; 190 } 191 192 LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser 193 << "\n"); 194 auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser); 195 if (read && read.indices() == write.transferWriteOp.indices() && 196 read.getVectorType() == write.transferWriteOp.getVectorType()) 197 return HoistableRead{read, subTensorOp}; 198 } 199 return HoistableRead(); 200 } 201 202 /// Check if the chunk of data inserted by the HoistableWrite are read by any 203 /// other op than the HoistableRead candidate. 204 static bool tensorChunkAccessedByUnknownOp(HoistableWrite write, 205 HoistableRead candidateRead, 206 BlockArgument tensorArg) { 207 // Make sure none of the other uses read the part of the tensor modified 208 // by the transfer_write. 209 llvm::SmallVector<Value::use_range, 1> uses; 210 uses.push_back(tensorArg.getUses()); 211 while (!uses.empty()) { 212 for (OpOperand &use : uses.pop_back_val()) { 213 Operation *user = use.getOwner(); 214 // Skip the candidate use, only inspect the "other" uses. 215 if (user == candidateRead.transferReadOp || 216 user == candidateRead.subTensorOp || user == write.transferWriteOp || 217 user == write.subTensorInsertOp) 218 continue; 219 // Consider all transitive uses through a subtensor / subtensor_insert. 220 // TODO: atm we just bail because a stronger analysis is needed for these 221 // cases. 222 if (isa<SubTensorOp, SubTensorInsertOp>(user)) 223 return true; 224 // Consider all transitive uses through a vector.transfer_write. 225 if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) { 226 uses.push_back(writeUser->getResult(0).getUses()); 227 continue; 228 } 229 // Consider all nested uses through an scf::ForOp. We may have 230 // pass-through tensor arguments left from previous level of 231 // hoisting. 232 if (auto forUser = dyn_cast<scf::ForOp>(user)) { 233 Value arg = forUser.getLoopBody().getArgument( 234 use.getOperandNumber() - forUser.getNumControlOperands() + 235 /*iv value*/ 1); 236 uses.push_back(arg.getUses()); 237 continue; 238 } 239 // Follow the use yield as long as it doesn't escape the original 240 // region. 241 scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user); 242 if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor( 243 yieldUser->getParentOp())) { 244 Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber()); 245 uses.push_back(ret.getUses()); 246 continue; 247 } 248 auto read = dyn_cast<vector::TransferReadOp>(user); 249 if (!read || !isDisjointTransferIndices( 250 cast<VectorTransferOpInterface>(read.getOperation()), 251 cast<VectorTransferOpInterface>( 252 write.transferWriteOp.getOperation()))) { 253 return true; 254 } 255 } 256 } 257 return false; 258 } 259 260 /// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`. 261 /// Return the null HoistableWrite() if it is not comprised of a 262 /// vector.transfer_write + optional subtensor_insert or if any of the indexings 263 /// is `forOp`-dependent. 264 static HoistableWrite 265 getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp, 266 OpOperand &yieldOperand) { 267 Value v = yieldOperand.get(); 268 if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) { 269 // Indexing must not depend on `forOp`. 270 for (Value operand : write.indices()) 271 if (!forOp.isDefinedOutsideOfLoop(operand)) 272 return HoistableWrite(); 273 274 return HoistableWrite{write, nullptr}; 275 } 276 277 if (auto subTensorInsertOp = v.getDefiningOp<SubTensorInsertOp>()) { 278 // Inserted subTensor must come from vector.transfer_write. 279 auto write = 280 subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>(); 281 if (!write) 282 return HoistableWrite(); 283 284 // Tensor inserted into must be a BBArg at position matching yieldOperand's. 285 auto bbArg = subTensorInsertOp.dest().dyn_cast<BlockArgument>(); 286 if (!bbArg || bbArg.getOwner()->getParentOp() != forOp || 287 bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber()) 288 return HoistableWrite(); 289 290 // Indexing inserted into must not depend on `forOp`. 291 for (Value operand : subTensorInsertOp->getOperands().drop_front( 292 SubTensorInsertOp::getOffsetSizeAndStrideStartOperandIndex())) 293 if (!forOp.isDefinedOutsideOfLoop(operand)) 294 return HoistableWrite(); 295 296 return HoistableWrite{write, subTensorInsertOp}; 297 } 298 299 return HoistableWrite(); 300 } 301 302 /// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair. 303 static void hoistReadWrite(HoistableRead read, HoistableWrite write, 304 BlockArgument tensorBBArg) { 305 scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp()); 306 assert(read.transferReadOp && write.transferWriteOp && 307 "expected transfer_read and transfer_write ops to be set"); 308 assert(((read.subTensorOp && write.subTensorInsertOp) || 309 (!read.subTensorOp && !write.subTensorInsertOp)) && 310 "expected matching subtensor / subtensor_insert"); 311 LLVM_DEBUG(DBGS() << "In forOp:\n" 312 << *forOp.getOperation() 313 << "\nHoist: " << *read.transferReadOp.getOperation() 314 << "\nHoist: " << *write.transferWriteOp.getOperation() 315 << "\nInvolving: " << tensorBBArg << "\n"); 316 317 // If a read subtensor is present, hoist it. 318 if (read.subTensorOp && failed(forOp.moveOutOfLoop({read.subTensorOp}))) 319 llvm_unreachable("Unexpected failure moving subtensor out of loop"); 320 321 // Hoist the transfer_read op. 322 if (failed(forOp.moveOutOfLoop({read.transferReadOp}))) 323 llvm_unreachable("Unexpected failure moving transfer read out of loop"); 324 325 // TODO: don't hardcode /*numIvs=*/1. 326 assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); 327 unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; 328 329 // Update the source tensor. 330 if (read.subTensorOp) 331 read.subTensorOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]); 332 else 333 read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]); 334 335 // Hoist write after. 336 if (write.subTensorInsertOp) 337 write.subTensorInsertOp->moveAfter(forOp); 338 write.transferWriteOp->moveAfter(forOp); 339 340 // Update the yield. 341 auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator()); 342 if (write.subTensorInsertOp) 343 yieldOp->setOperand(initArgNumber, write.subTensorInsertOp.dest()); 344 else 345 yieldOp->setOperand(initArgNumber, write.transferWriteOp.source()); 346 347 // Rewrite `loop` with additional new yields. 348 OpBuilder b(read.transferReadOp); 349 auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(), 350 write.transferWriteOp.vector()); 351 // Transfer write has been hoisted, need to update the vector and tensor 352 // source. Replace the result of the loop to use the new tensor created 353 // outside the loop. 354 // Depending on whether a subtensor_insert is present or not, it carries the 355 // update on the tensor operands. 356 if (write.subTensorInsertOp) { 357 newForOp.getResult(initArgNumber) 358 .replaceAllUsesWith(write.subTensorInsertOp.getResult()); 359 write.transferWriteOp.sourceMutable().assign(read.subTensorOp.result()); 360 write.subTensorInsertOp.destMutable().assign(read.subTensorOp.source()); 361 } else { 362 newForOp.getResult(initArgNumber) 363 .replaceAllUsesWith(write.transferWriteOp.getResult(0)); 364 write.transferWriteOp.sourceMutable().assign( 365 newForOp.getResult(initArgNumber)); 366 } 367 368 // Always update with the newly yield tensor and vector. 369 write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back()); 370 } 371 372 // To hoist transfer op on tensor the logic can be significantly simplified 373 // compared to the case on buffer. The transformation follows this logic: 374 // 1. Look for transfer_write with a single use from ForOp yield 375 // 2. Check the uses of the matching block argument and look for a transfer_read 376 // with the same indices. 377 // 3. Check that all the other uses of the tensor argument are either disjoint 378 // tensor_read or transfer_write. For transfer_write uses recurse to make sure 379 // the new tensor has the same restrictions on its uses. 380 // 4. Hoist the tensor_read/tensor_write and update the tensor SSA links. 381 // After this transformation the scf.forOp may have unused arguments that can be 382 // remove by the canonicalization pass. 383 void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) { 384 bool changed = true; 385 while (changed) { 386 changed = false; 387 func.walk([&](scf::ForOp forOp) { 388 Operation *yield = forOp.getBody()->getTerminator(); 389 for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) { 390 OpOperand &ret = yield->getOpOperand(it.index()); 391 HoistableWrite write = 392 getLoopInvariantTransferWriteOpDefining(forOp, ret); 393 if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse()) 394 continue; 395 LLVM_DEBUG(dbgs() << "\n"; 396 DBGS() << "Candidate write for hoisting: " 397 << *write.transferWriteOp.getOperation() << "\n"); 398 if (write.subTensorInsertOp) 399 LLVM_DEBUG(DBGS() << "Candidate subtensor_insert for hoisting: " 400 << *write.subTensorInsertOp.getOperation() << "\n"); 401 if (llvm::any_of(write.transferWriteOp.indices(), 402 [&forOp](Value index) { 403 return !forOp.isDefinedOutsideOfLoop(index); 404 })) 405 continue; 406 // Find a read with the same type and indices. 407 HoistableRead matchingRead = 408 findMatchingTransferRead(write, it.value()); 409 // Make sure none of the other uses read the part of the tensor modified 410 // by the transfer_write. 411 if (!matchingRead.transferReadOp || 412 tensorChunkAccessedByUnknownOp(write, matchingRead, it.value())) 413 continue; 414 415 LLVM_DEBUG(DBGS() << "Start hoisting\n"); 416 hoistReadWrite(matchingRead, write, it.value()); 417 changed = true; 418 forOp.erase(); 419 420 // Need to interrupt and restart: erasing the loop messes up the walk. 421 return WalkResult::interrupt(); 422 } 423 return WalkResult::advance(); 424 }); 425 // Apply canonicalization so the newForOp + yield folds immediately, thus 426 // cleaning up the IR and potentially enabling more hoisting. 427 if (changed) { 428 OwningRewritePatternList patterns; 429 scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext()); 430 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 431 } 432 } 433 } 434 435 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 436 bool changed = true; 437 while (changed) { 438 changed = false; 439 440 func.walk([&](vector::TransferReadOp transferRead) { 441 if (!transferRead.getShapedType().isa<MemRefType>()) 442 return WalkResult::advance(); 443 444 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 445 << *transferRead.getOperation() << "\n"); 446 auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp()); 447 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() 448 << "\n"); 449 if (!loop) 450 return WalkResult::advance(); 451 452 if (failed(moveLoopInvariantCode( 453 cast<LoopLikeOpInterface>(loop.getOperation())))) 454 llvm_unreachable( 455 "Unexpected failure to move invariant code out of loop"); 456 457 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 458 << "\n"); 459 460 llvm::SetVector<Operation *> forwardSlice; 461 getForwardSlice(transferRead.getOperation(), &forwardSlice); 462 463 // Look for the last TransferWriteOp in the forwardSlice of 464 // `transferRead` that operates on the same memref. 465 vector::TransferWriteOp transferWrite; 466 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 467 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 468 if (!candidateWrite || candidateWrite.source() != transferRead.source()) 469 continue; 470 transferWrite = candidateWrite; 471 } 472 473 // All operands of the TransferRead must be defined outside of the loop. 474 for (auto operand : transferRead.getOperands()) 475 if (!loop.isDefinedOutsideOfLoop(operand)) 476 return WalkResult::advance(); 477 478 // Only hoist transfer_read / transfer_write pairs for now. 479 if (!transferWrite) 480 return WalkResult::advance(); 481 482 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 483 << "\n"); 484 485 // Approximate aliasing by checking that: 486 // 1. indices are the same, 487 // 2. no other operations in the loop access the same memref except 488 // for transfer_read/transfer_write accessing statically disjoint 489 // slices. 490 if (transferRead.indices() != transferWrite.indices() && 491 transferRead.getVectorType() == transferWrite.getVectorType()) 492 return WalkResult::advance(); 493 494 // TODO: may want to memoize this information for performance but it 495 // likely gets invalidated often. 496 DominanceInfo dom(loop); 497 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 498 return WalkResult::advance(); 499 for (auto &use : transferRead.source().getUses()) { 500 if (!dom.properlyDominates(loop, use.getOwner())) 501 continue; 502 if (use.getOwner() == transferRead.getOperation() || 503 use.getOwner() == transferWrite.getOperation()) 504 continue; 505 if (auto transferWriteUse = 506 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 507 if (!isDisjointTransferSet( 508 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 509 cast<VectorTransferOpInterface>( 510 transferWriteUse.getOperation()))) 511 return WalkResult::advance(); 512 } else if (auto transferReadUse = 513 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 514 if (!isDisjointTransferSet( 515 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 516 cast<VectorTransferOpInterface>( 517 transferReadUse.getOperation()))) 518 return WalkResult::advance(); 519 } else { 520 // Unknown use, we cannot prove that it doesn't alias with the 521 // transferRead/transferWrite operations. 522 return WalkResult::advance(); 523 } 524 } 525 526 // Hoist read before. 527 if (failed(loop.moveOutOfLoop({transferRead}))) 528 llvm_unreachable( 529 "Unexpected failure to move transfer read out of loop"); 530 531 // Hoist write after. 532 transferWrite->moveAfter(loop); 533 534 // Rewrite `loop` with new yields by cloning and erase the original loop. 535 OpBuilder b(transferRead); 536 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 537 transferWrite.vector()); 538 539 // Transfer write has been hoisted, need to update the written value to 540 // the value yielded by the newForOp. 541 transferWrite.vector().replaceAllUsesWith( 542 newForOp.getResults().take_back()[0]); 543 544 changed = true; 545 loop.erase(); 546 // Need to interrupt and restart because erasing the loop messes up the 547 // walk. 548 return WalkResult::interrupt(); 549 }); 550 } 551 } 552 553 /// Ensure prerequisites that guarantee pad op hoisting can occur. 554 /// Return failure in the cases when we cannot perform hoisting; i.e. if either: 555 /// 1. There exists a use of `padTensorOp` that is not a linalg input operand. 556 /// 2. There isn't an enclosing `outermostEnclosingForOp` loop. 557 /// 3. There exists an op with a region that is dominated by 558 /// `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a 559 /// LinalgOp. 560 /// 3. There exists an op with side effects that is dominated by 561 /// `outermostEnclosingForOp` and that isn't a LoopLikeInterface. 562 /// 563 /// While ensuring prerequisites: 564 /// 1. Fill the `backwardSlice` to contain the topologically sorted ops 565 /// dominated by `outermostEnclosingForOp`. 566 /// 2. Fill the `packingLoops` to contain only the enclosing loops of 567 /// `backwardSlice` whose IV is actually used in computing padding. Loops that 568 /// remain in `backwardSlice` but that are not in `packingLoops` are 569 /// dimensions of reuse. 570 static LogicalResult 571 hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels, 572 llvm::SetVector<Operation *> &backwardSlice, 573 llvm::SetVector<Operation *> &packingLoops) { 574 // Bail on any use that isn't an input of a Linalg op. 575 // Hoisting of inplace updates happens after vectorization. 576 for (OpOperand &use : padTensorOp.result().getUses()) { 577 auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner()); 578 if (!linalgUser || !linalgUser.isInputTensor(&use)) 579 return failure(); 580 } 581 582 // Get at most nLevels of enclosing loops. 583 SmallVector<LoopLikeOpInterface> reverseEnclosingLoops; 584 Operation *outermostEnclosingForOp = nullptr, 585 *nextEnclosingForOp = 586 padTensorOp->getParentOfType<LoopLikeOpInterface>(); 587 while (nLevels-- > 0 && nextEnclosingForOp) { 588 outermostEnclosingForOp = nextEnclosingForOp; 589 reverseEnclosingLoops.push_back(outermostEnclosingForOp); 590 nextEnclosingForOp = 591 nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>(); 592 } 593 if (!outermostEnclosingForOp) 594 return failure(); 595 596 // Get the backwards slice from `padTensorOp` that is dominated by the 597 // outermost enclosing loop. 598 DominanceInfo domInfo(outermostEnclosingForOp); 599 getBackwardSlice(padTensorOp.getOperation(), &backwardSlice, 600 [&](Operation *op) { 601 return domInfo.dominates(outermostEnclosingForOp, op); 602 }); 603 604 // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp. 605 if (llvm::any_of(backwardSlice, [](Operation *op) { 606 return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) && 607 !isa<LinalgOp>(op); 608 })) 609 return failure(); 610 611 // Filter out the loops whose induction variable is not used to compute the 612 // padded result. As a first approximation, just look for IVs that have no use 613 // in the backwardSlice. 614 // These are the dimensions of reuse that we can exploit to reduce the amount 615 // of work / memory. 616 // TODO: would this optimization compose better as a canonicalization? 617 for (LoopLikeOpInterface loop : reverseEnclosingLoops) { 618 auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); 619 if (!forOp) 620 continue; 621 for (Operation *user : forOp.getInductionVar().getUsers()) { 622 if (backwardSlice.contains(user)) { 623 packingLoops.insert(forOp); 624 break; 625 } 626 } 627 } 628 629 // Backward slice is a topologically sorted list of ops starting at 630 // `outermostEnclosingForOp`. 631 assert(outermostEnclosingForOp == backwardSlice.front()); 632 633 return success(); 634 } 635 636 /// Return the number of iterations in the loop (ub - lb).ceilDiv(step). 637 static Value buildLoopTripCount(OpBuilder &b, scf::ForOp forOp) { 638 MLIRContext *ctx = forOp->getContext(); 639 AffineExpr lb, ub, step; 640 bindDims(ctx, lb, ub); 641 bindSymbols(ctx, step); 642 return b.create<AffineApplyOp>( 643 forOp->getLoc(), AffineMap::get(2, 1, {(ub - lb).ceilDiv(step)}, ctx), 644 ValueRange{forOp.lowerBound(), forOp.upperBound(), forOp.step()}); 645 } 646 647 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step). 648 static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp forOp) { 649 MLIRContext *ctx = forOp->getContext(); 650 AffineExpr iv, lb, step; 651 bindDims(ctx, iv, lb); 652 bindSymbols(ctx, step); 653 return b.create<AffineApplyOp>( 654 forOp->getLoc(), AffineMap::get(2, 1, {(iv - lb).ceilDiv(step)}, ctx), 655 ValueRange{forOp.getInductionVar(), forOp.lowerBound(), forOp.step()}); 656 } 657 658 LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp, 659 unsigned nLoops) { 660 llvm::SetVector<Operation *> backwardSlice, packingLoops; 661 if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops, 662 backwardSlice, packingLoops))) 663 return failure(); 664 665 // Update actual number of loops, which may be smaller. 666 nLoops = packingLoops.size(); 667 668 Location loc = padTensorOp->getLoc(); 669 RankedTensorType paddedTensorType = padTensorOp.getResultType(); 670 unsigned paddedRank = paddedTensorType.getRank(); 671 672 // Backward slice is a topologically sorted list of ops starting at 673 // `outermostEnclosingForOp`. 674 Operation *outermostEnclosingForOp = backwardSlice.front(); 675 // IP just before the outermost loop considered that we hoist above. 676 OpBuilder b(outermostEnclosingForOp); 677 678 // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize 679 // padding. 680 SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize); 681 // TODO: go grab dims when necessary, for now PadTensorOp returns a static 682 // tensor. 683 llvm::append_range(packedShape, paddedTensorType.getShape()); 684 auto packedTensorType = 685 RankedTensorType::get(packedShape, paddedTensorType.getElementType()); 686 auto dynamicSizes = 687 llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) { 688 return buildLoopTripCount(b, cast<scf::ForOp>(op)); 689 })); 690 Value packedTensor = b.create<linalg::InitTensorOp>( 691 loc, dynamicSizes, packedTensorType.getShape(), 692 packedTensorType.getElementType()); 693 694 // Clone the operations involved in the backward slice, iteratively stepping 695 // into the loops that we encounter. 696 // The implementation proceeds in a stack-like fashion: 697 // 1. Iteratively clone and step into the loops, pushing the `packedTensor` 698 // deeper in the stack. 699 // 2. Create a SubTensorInsert at the top of the stack. 700 // 3. Iteratively pop and yield the result of the SubTensorInsertOp across 701 // the cloned loops. 702 SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings; 703 clonedLoopIvs.reserve(nLoops); 704 leadingPackedTensorIndexings.reserve(nLoops); 705 BlockAndValueMapping bvm; 706 // Stack step 1. iteratively clone loops and push `packedTensor`. 707 // Insert `padTensorOp` into the backwardSlice so we clone it too. 708 backwardSlice.insert(padTensorOp); 709 for (Operation *op : backwardSlice) { 710 if (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op)) { 711 b.clone(*op, bvm); 712 continue; 713 } 714 // TODO: support more cases as they appear. 715 auto forOp = dyn_cast<scf::ForOp>(op); 716 assert(forOp && "Expected scf::ForOp when hoisting pad ops"); 717 // Unused loop, just skip it. 718 if (!packingLoops.contains(forOp)) 719 continue; 720 auto clonedForOp = 721 b.create<scf::ForOp>(loc, forOp.lowerBound(), forOp.upperBound(), 722 forOp.step(), packedTensor); 723 assert(clonedForOp->getNumRegions() == 1); 724 clonedLoopIvs.push_back(clonedForOp.getInductionVar()); 725 b.setInsertionPointToStart(&clonedForOp->getRegion(0).front()); 726 leadingPackedTensorIndexings.push_back( 727 buildLoopIterationCount(b, clonedForOp)); 728 bvm.map(forOp.getInductionVar(), clonedLoopIvs.back()); 729 packedTensor = clonedForOp.getRegionIterArgs().front(); 730 } 731 732 // Stack step 2. create SubTensorInsertOp at the top of the stack. 733 // offsets = [clonedLoopIvs, 0 .. 0]. 734 SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(), 735 leadingPackedTensorIndexings.end()); 736 offsets.append(paddedRank, b.getIndexAttr(0)); 737 // sizes = [1 .. 1, paddedShape]. 738 SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1)); 739 for (int64_t sz : paddedTensorType.getShape()) { 740 // TODO: go grab dims when necessary, for now PadTensorOp returns a static 741 // tensor. 742 assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes"); 743 sizes.push_back(b.getIndexAttr(sz)); 744 } 745 // strides = [1 .. 1]. 746 SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1)); 747 748 Value inserted = 749 b.create<SubTensorInsertOp>(loc, bvm.lookup(padTensorOp.result()), 750 packedTensor, offsets, sizes, strides); 751 752 // Stack step 3. iteratively pop the stack and propagate the yield. 753 Value valueToYield = inserted; 754 for (Value iv : llvm::reverse(clonedLoopIvs)) { 755 auto forOp = scf::getForInductionVarOwner(iv); 756 b.setInsertionPointToEnd(&forOp.getRegion().front()); 757 b.create<scf::YieldOp>(loc, valueToYield); 758 valueToYield = forOp.getResult(0); 759 } 760 761 // Now the packed tensor is ready, replace the original padding op by a 762 // 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1]. 763 b.setInsertionPoint(padTensorOp); 764 SmallVector<Value> loopIterationCounts = 765 llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) { 766 return buildLoopIterationCount(b, cast<scf::ForOp>(loop)); 767 })); 768 // offsets = [originalLoopIvs, 0 .. 0]. 769 offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end()); 770 offsets.append(paddedRank, b.getIndexAttr(0)); 771 // sizes = [1 .. 1, paddedShape] (definedabove). 772 // strides = [1 .. 1] (defined above) 773 packedTensor = 774 scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0); 775 padTensorOp.replaceAllUsesWith( 776 b.create<SubTensorOp>(loc, padTensorOp.getResultType(), packedTensor, 777 offsets, sizes, strides) 778 ->getResult(0)); 779 780 Operation *toErase = padTensorOp; 781 782 // Make the newly cloned `padTensorOp` available to the caller. 783 padTensorOp = 784 cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp()); 785 786 toErase->erase(); 787 788 return success(); 789 } 790