//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
/// When hoisting out of an scf.for on tensors, the write may be fed by a
/// subtensor_insert; in that case `subTensorInsertOp` is non-null and must be
/// moved together with `transferWriteOp`.
struct HoistableWrite {
  vector::TransferWriteOp transferWriteOp;
  SubTensorInsertOp subTensorInsertOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
/// Mirror image of HoistableWrite: when the read goes through a subtensor,
/// `subTensorOp` is non-null and is hoisted together with `transferReadOp`.
struct HoistableRead {
  vector::TransferReadOp transferReadOp;
  SubTensorOp subTensorOp;
};
} // namespace

/// Return true if op1 and op2 are the same constant or the same SSA value.
static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
  // Extract a constant integer from an OpFoldResult, whether it is carried as
  // an Attribute or as an SSA Value produced by a ConstantOp.
  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
    Attribute attr = ofr.dyn_cast<Attribute>();
    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
    if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
      attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
      return intAttr.getValue().getSExtValue();
    return llvm::None;
  };
  // Equal if both fold to the same constant integer...
  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
  if (cst1 && cst2 && *cst1 == *cst2)
    return true;
  // ... or if both are the exact same SSA value.
  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
  return v1 && v2 && v1 == v2;
}

/// Return true if all offsets, sizes and strides of `s` and `si` are equal,
/// element by element (same constant or same SSA value).
static bool sameOffsetsSizesAndStrides(SubTensorOp s, SubTensorInsertOp si) {
  // Rank/arity must match before comparing element-wise.
  if (s.static_offsets().size() != si.static_offsets().size())
    return false;
  if (s.static_sizes().size() != si.static_sizes().size())
    return false;
  if (s.static_strides().size() != si.static_strides().size())
    return false;
  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  return true;
}

/// Look for a HoistableRead, in the given tensor uses, accessing the same
/// offset as the HoistableWrite.
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.subTensorInsertOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead subTensorInsertOp: "
                      << *write.subTensorInsertOp.getOperation() << "\n");

  // Walk all users of `srcTensor` looking for a transfer_read (possibly
  // through a subtensor) that mirrors the write.
  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If HoistableWrite involves a SubTensorInsertOp, we need to find a
    // matching SubTensorOp.
    SubTensorOp subTensorOp;
    Operation *maybeTransferReadUser = user;
    if (write.subTensorInsertOp) {
      // The read-side subtensor must produce the same type as the value the
      // write-side subtensor_insert inserts.
      subTensorOp = dyn_cast<SubTensorOp>(user);
      if (!subTensorOp || subTensorOp.getResult().getType() !=
                              write.subTensorInsertOp.source().getType())
        continue;

      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *subTensorOp << " vs " << *write.subTensorInsertOp
                        << "\n");
      if (!sameOffsetsSizesAndStrides(subTensorOp, write.subTensorInsertOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // If we got here, subTensorOp is hoistable iff it has exactly 2 uses:
      //   1. the transfer_write we want to hoist.
      //   2. a matching transfer_read.
      // Anything else, we skip.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : subTensorOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        // A second non-write user means more than 2 uses overall: bail.
        if (otherUser) {
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    // A match requires identical indices and identical vector type so the
    // read covers exactly the chunk the write produces.
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    if (read && read.indices() == write.transferWriteOp.indices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, subTensorOp};
  }
  // Null HoistableRead signals "no match found" to the caller.
  return HoistableRead();
}

/// Check if the chunk of data inserted by the HoistableWrite are read by any
/// other op than the HoistableRead candidate.
/// Conservative use-def walk: returns true (i.e. "unknown access, do not
/// hoist") as soon as a use cannot be proven disjoint or pass-through.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  // Worklist of use ranges: transitive uses are pushed and drained until
  // empty or until an unknown access is found.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.subTensorOp || user == write.transferWriteOp ||
          user == write.subTensorInsertOp)
        continue;
      // Consider all transitive uses through a subtensor / subtensor_insert.
      // TODO: atm we just bail because a stronger analysis is needed for these
      // cases.
      if (isa<SubTensorOp, SubTensorInsertOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        // Map the operand to the corresponding region iter_arg (+1 skips the
        // induction variable block argument).
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // A transfer_read is acceptable only if statically disjoint from the
      // hoisted write; anything else is an unknown access.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}

/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional subtensor_insert or if any of the indexings
/// is `forOp`-dependent.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  // Case 1: the yielded value comes straight from a transfer_write.
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  // Case 2: the yielded value comes from a subtensor_insert fed by a
  // transfer_write.
  if (auto subTensorInsertOp = v.getDefiningOp<SubTensorInsertOp>()) {
    // Inserted subTensor must come from vector.transfer_write.
    auto write =
        subTensorInsertOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
    auto bbArg = subTensorInsertOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    for (Value operand : subTensorInsertOp->getOperands().drop_front(
             SubTensorInsertOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, subTensorInsertOp};
  }

  return HoistableWrite();
}

/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
/// Moves the read (and optional subtensor) above `forOp`, the write (and
/// optional subtensor_insert) below it, and re-threads the loop-carried
/// tensor/vector values through a cloned loop with extra yields.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.subTensorOp && write.subTensorInsertOp) ||
          (!read.subTensorOp && !write.subTensorInsertOp)) &&
         "expected matching subtensor / subtensor_insert");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read subtensor is present, hoist it.
  if (read.subTensorOp && failed(forOp.moveOutOfLoop({read.subTensorOp})))
    llvm_unreachable("Unexpected failure moving subtensor out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor: after hoisting, the read must source from the
  // loop's init operand rather than the (now out-of-scope) block argument.
  if (read.subTensorOp)
    read.subTensorOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.subTensorInsertOp)
    write.subTensorInsertOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield: the loop now just forwards the tensor it received.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.subTensorInsertOp)
    yieldOp->setOperand(initArgNumber, write.subTensorInsertOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields: the hoisted read's vector is
  // carried through the loop and the written vector is yielded out.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a subtensor_insert is present or not, it carries the
  // update on the tensor operands.
  if (write.subTensorInsertOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.subTensorInsertOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.subTensorOp.result());
    write.subTensorInsertOp.destMutable().assign(read.subTensorOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}

// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
// with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
// tensor_read or transfer_write. For transfer_write uses recurse to make sure
// the new tensor has the same restrictions on its uses.
// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can be
// removed by the canonicalization pass.
338 void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) { 339 bool changed = true; 340 while (changed) { 341 changed = false; 342 func.walk([&](scf::ForOp forOp) { 343 Operation *yield = forOp.getBody()->getTerminator(); 344 for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) { 345 OpOperand &ret = yield->getOpOperand(it.index()); 346 HoistableWrite write = 347 getLoopInvariantTransferWriteOpDefining(forOp, ret); 348 if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse()) 349 continue; 350 LLVM_DEBUG(dbgs() << "\n"; 351 DBGS() << "Candidate write for hoisting: " 352 << *write.transferWriteOp.getOperation() << "\n"); 353 if (write.subTensorInsertOp) 354 LLVM_DEBUG(DBGS() << "Candidate subtensor_insert for hoisting: " 355 << *write.subTensorInsertOp.getOperation() << "\n"); 356 if (llvm::any_of(write.transferWriteOp.indices(), 357 [&forOp](Value index) { 358 return !forOp.isDefinedOutsideOfLoop(index); 359 })) 360 continue; 361 // Find a read with the same type and indices. 362 HoistableRead matchingRead = 363 findMatchingTransferRead(write, it.value()); 364 // Make sure none of the other uses read the part of the tensor modified 365 // by the transfer_write. 366 if (!matchingRead.transferReadOp || 367 tensorChunkAccessedByUnknownOp(write, matchingRead, it.value())) 368 continue; 369 370 LLVM_DEBUG(DBGS() << "Start hoisting\n"); 371 hoistReadWrite(matchingRead, write, it.value()); 372 changed = true; 373 forOp.erase(); 374 375 // Need to interrupt and restart: erasing the loop messes up the walk. 376 return WalkResult::interrupt(); 377 } 378 return WalkResult::advance(); 379 }); 380 // Apply canonicalization so the newForOp + yield folds immediately, thus 381 // cleaning up the IR and potentially enabling more hoisting. 
382 if (changed) { 383 RewritePatternSet patterns(func->getContext()); 384 scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext()); 385 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 386 } 387 } 388 } 389 390 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 391 bool changed = true; 392 while (changed) { 393 changed = false; 394 395 func.walk([&](vector::TransferReadOp transferRead) { 396 if (!transferRead.getShapedType().isa<MemRefType>()) 397 return WalkResult::advance(); 398 399 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 400 << *transferRead.getOperation() << "\n"); 401 auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp()); 402 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() 403 << "\n"); 404 if (!loop) 405 return WalkResult::advance(); 406 407 if (failed(moveLoopInvariantCode( 408 cast<LoopLikeOpInterface>(loop.getOperation())))) 409 llvm_unreachable( 410 "Unexpected failure to move invariant code out of loop"); 411 412 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 413 << "\n"); 414 415 SetVector<Operation *> forwardSlice; 416 getForwardSlice(transferRead.getOperation(), &forwardSlice); 417 418 // Look for the last TransferWriteOp in the forwardSlice of 419 // `transferRead` that operates on the same memref. 420 vector::TransferWriteOp transferWrite; 421 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 422 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 423 if (!candidateWrite || candidateWrite.source() != transferRead.source()) 424 continue; 425 transferWrite = candidateWrite; 426 } 427 428 // All operands of the TransferRead must be defined outside of the loop. 429 for (auto operand : transferRead.getOperands()) 430 if (!loop.isDefinedOutsideOfLoop(operand)) 431 return WalkResult::advance(); 432 433 // Only hoist transfer_read / transfer_write pairs for now. 
434 if (!transferWrite) 435 return WalkResult::advance(); 436 437 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 438 << "\n"); 439 440 // Approximate aliasing by checking that: 441 // 1. indices are the same, 442 // 2. no other operations in the loop access the same memref except 443 // for transfer_read/transfer_write accessing statically disjoint 444 // slices. 445 if (transferRead.indices() != transferWrite.indices() && 446 transferRead.getVectorType() == transferWrite.getVectorType()) 447 return WalkResult::advance(); 448 449 // TODO: may want to memoize this information for performance but it 450 // likely gets invalidated often. 451 DominanceInfo dom(loop); 452 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 453 return WalkResult::advance(); 454 for (auto &use : transferRead.source().getUses()) { 455 if (!dom.properlyDominates(loop, use.getOwner())) 456 continue; 457 if (use.getOwner() == transferRead.getOperation() || 458 use.getOwner() == transferWrite.getOperation()) 459 continue; 460 if (auto transferWriteUse = 461 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 462 if (!isDisjointTransferSet( 463 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 464 cast<VectorTransferOpInterface>( 465 transferWriteUse.getOperation()))) 466 return WalkResult::advance(); 467 } else if (auto transferReadUse = 468 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 469 if (!isDisjointTransferSet( 470 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 471 cast<VectorTransferOpInterface>( 472 transferReadUse.getOperation()))) 473 return WalkResult::advance(); 474 } else { 475 // Unknown use, we cannot prove that it doesn't alias with the 476 // transferRead/transferWrite operations. 477 return WalkResult::advance(); 478 } 479 } 480 481 // Hoist read before. 
482 if (failed(loop.moveOutOfLoop({transferRead}))) 483 llvm_unreachable( 484 "Unexpected failure to move transfer read out of loop"); 485 486 // Hoist write after. 487 transferWrite->moveAfter(loop); 488 489 // Rewrite `loop` with new yields by cloning and erase the original loop. 490 OpBuilder b(transferRead); 491 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 492 transferWrite.vector()); 493 494 // Transfer write has been hoisted, need to update the written value to 495 // the value yielded by the newForOp. 496 transferWrite.vector().replaceAllUsesWith( 497 newForOp.getResults().take_back()[0]); 498 499 changed = true; 500 loop.erase(); 501 // Need to interrupt and restart because erasing the loop messes up the 502 // walk. 503 return WalkResult::interrupt(); 504 }); 505 } 506 } 507 508 /// Return success if `v` is a value that is only transitively defined by ops of 509 /// type in `OpTypeList`. 510 template <typename... OpTypeList> 511 static bool backwardsSliceOnlyHasOpsOfType(scf::ForOp outerLimit, Value v) { 512 // Compute a backward slice up to, but not including, `outerLimit`. 513 SetVector<Operation *> backwardSlice; 514 getBackwardSlice(v, &backwardSlice, [&](Operation *op) { 515 return outerLimit->isProperAncestor(op); 516 }); 517 // Traverse the backward slice and ensure we can perform the computation to 518 // hoist. 519 for (Operation *op : backwardSlice) { 520 if (isa<OpTypeList...>(op)) 521 continue; 522 LLVM_DEBUG(DBGS() << "Abort: unadmissible op in slice " << *op << "\n"); 523 return false; 524 } 525 return true; 526 } 527 528 bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) { 529 return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>(); 530 } 531 532 /// Compute the tightest lower bound with quantities that are all defined 533 /// outside of `outer`. 534 /// Return null if such a bound cannot be computed. 
Value computeLoopIndependentLowerBound(OpBuilder &b, scf::ForOp outer,
                                       Value v) {
  // Only the trivial case is handled: `v` already usable above `outer`.
  // `b` is unused here but kept for signature symmetry with the upper-bound
  // computation.
  if (isDefinedOutsideOrConstant(outer, v))
    return v;
  return Value();
}

/// Compute the tightest upper bound with quantities that are all defined
/// outside of `outer`.
/// Expects all ops in the backward slice of `v` up to `outer` to be either
/// scf.for, affine.min or affine.apply.
static Value computeLoopIndependentUpperBound(OpBuilder &b, scf::ForOp outer,
                                              Value v) {
  // Fast path: already loop-independent.
  if (isDefinedOutsideOrConstant(outer, v))
    return v;

  LLVM_DEBUG(DBGS() << "Begin loopIndependentUpperBound for: " << v << "\n");

  bool ok =
      backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp, AffineApplyOp>(
          outer, v);
  assert(ok && "expected to only be defined by scf::ForOp and AffineMinOp");
  (void)ok;

  // Compute a backward slice up to, but not including, `outer`.
  SetVector<Operation *> backwardSlice;
  getBackwardSlice(v, &backwardSlice,
                   [&](Operation *op) { return outer->isProperAncestor(op); });
  // NOTE(review): assumes `v` has a defining op (not a block argument) —
  // a null insert here would be a bug; confirm callers guarantee this.
  backwardSlice.insert(v.getDefiningOp());

  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(outer);
  Value res = v;
  // `bvm` maps original slice values to their clones above `outer`.
  BlockAndValueMapping bvm;
  for (Operation *op : backwardSlice) {
    // scf.for ops in the slice are only dimension carriers; their IVs are
    // substituted via getSCFMinMaxExpr below, so they are not cloned.
    if (isa<scf::ForOp>(op))
      continue;
    if (isa<AffineApplyOp>(op)) {
      b.clone(*op, bvm);
      continue;
    }
    auto sliceMinOp = cast<AffineMinOp>(op);
    // Callback giving min/max expressions for values defined inside `outer`
    // (e.g. loop IVs), so affine.min operands can be rewritten in terms of
    // loop-independent quantities.
    GetMinMaxExprFn getSCFMinMax = [&](Value value,
                                       SmallVectorImpl<Value> &dims,
                                       SmallVectorImpl<Value> &symbols) {
      return getSCFMinMaxExpr(value, dims, symbols, [&](Operation *op) {
        return outer->isAncestor(op);
      });
    };
    // Perform the substitution of the operands of AffineMinOp.
    auto mapAndOperands = substituteMin(sliceMinOp, getSCFMinMax);
    SmallVector<Value> resultOperands = mapAndOperands.dims;
    llvm::append_range(resultOperands, mapAndOperands.symbols);
    AffineMap map = mapAndOperands.map;
    canonicalizeMapAndOperands(&map, &resultOperands);
    map = simplifyAffineMap(map);
    // Re-create the affine.min above `outer`, remapping operands through any
    // clones made earlier in this walk.
    res = b.create<AffineMinOp>(
        outer->getLoc(), map,
        llvm::to_vector<4>(llvm::map_range(resultOperands, [&](Value operand) {
          return bvm.lookupOrDefault(operand);
        })));
    bvm.map(sliceMinOp, res);
  }
  LLVM_DEBUG(DBGS() << "End loopIndependentUpperBound with: " << res << "\n");
  return res;
}

/// Return the number of iterations in the loop (ub - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
static Value buildLoopTripCount(OpBuilder &b, scf::ForOp outer,
                                scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  // d0 = lb, d1 = ub, s0 = step.
  AffineExpr lb, ub, step;
  bindDims(ctx, lb, ub);
  bindSymbols(ctx, step);
  Value lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
        ubVal = computeLoopIndependentUpperBound(b, outer, forOp.upperBound()),
        stepVal = forOp.step();
  if (!lbVal || !ubVal || !stepVal)
    return Value();
  auto loc = forOp->getLoc();
  Value res = b.create<AffineApplyOp>(loc, (ub - lb).ceilDiv(step),
                                      ValueRange{lbVal, ubVal, stepVal});
  return res;
}

/// Return the current iteration number in the loop (iv - lb).ceilDiv(step).
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  // d0 = iv, d1 = lb, s0 = step.
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  Value ivVal = forOp.getInductionVar(),
        lbVal = computeLoopIndependentLowerBound(b, outer, forOp.lowerBound()),
        stepVal = forOp.step();
  if (!ivVal || !lbVal || !stepVal)
    return Value();
  auto loc = forOp->getLoc();
  return b.create<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
                                 ValueRange{ivVal, lbVal, stepVal});
}

/// Ensure prerequisites that guarantee pad op hoisting can occur.
/// Return failure in the cases when we cannot perform hoisting; i.e. if either:
///   1. There exists a use of `padTensorOp` that is not a linalg input operand.
///   2. There isn't an enclosing `outermostEnclosingForOp` loop.
///   3. There exists an op with a region that is dominated by
///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
///    LinalgOp.
///   4. There exists an op with side effects that is dominated by
///   `outermostEnclosingForOp` and that isn't a LoopLikeInterface.
///   5. The lower bound, upper bound and step of all the loops involved in the
///   hoisting can be
///
/// While ensuring prerequisites:
///   1. Fill the `backwardSlice` to contain the topologically sorted ops
///   dominated by `outermostEnclosingForOp`.
///   2. Fill the `packingLoops` to contain only the enclosing loops of
///   `backwardSlice` whose IV is actually used in computing padding. Loops that
///   remain in `backwardSlice` but that are not in `packingLoops` are
///   dimensions of reuse.
static LogicalResult
hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
                                   SetVector<Operation *> &backwardSlice,
                                   SetVector<Operation *> &packingLoops,
                                   SmallVector<Value> &dynamicTensorSizes) {
  // Bail on any use that isn't an input of a Linalg op.
  // Hoisting of inplace updates happens after vectorization.
  for (OpOperand &use : padTensorOp.result().getUses()) {
    auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
    if (!linalgUser || !linalgUser.isInputTensor(&use))
      return failure();
  }

  // Get at most nLevels of enclosing loops, innermost first in
  // `reverseEnclosingLoops`.
  SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
  Operation *outermostEnclosingForOp = nullptr,
            *nextEnclosingForOp =
                padTensorOp->getParentOfType<LoopLikeOpInterface>();
  while (nLevels-- > 0 && nextEnclosingForOp) {
    outermostEnclosingForOp = nextEnclosingForOp;
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingForOp =
        nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
  }
  // No enclosing loop at all: nothing to hoist above.
  if (!outermostEnclosingForOp)
    return failure();

  // Get the backwards slice from `padTensorOp` that is dominated by the
  // outermost enclosing loop.
  DominanceInfo domInfo(outermostEnclosingForOp);
  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
                   [&](Operation *op) {
                     return domInfo.dominates(outermostEnclosingForOp, op);
                   });

  // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
  if (llvm::any_of(backwardSlice, [](Operation *op) {
        return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
               !isa<LinalgOp>(op);
      }))
    return failure();

  // Filter out the loops whose induction variable is not used to compute the
  // padded result. As a first approximation, just look for IVs that have no use
  // in the backwardSlice.
  // These are the dimensions of reuse that we can exploit to reduce the amount
  // of work / memory.
  // TODO: would this optimization compose better as a canonicalization?
  // Note: reversed iteration visits loops outermost-first, so `packingLoops`
  // ends up ordered outermost to innermost.
  for (LoopLikeOpInterface loop : llvm::reverse(reverseEnclosingLoops)) {
    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
    if (!forOp)
      continue;
    for (Operation *user : forOp.getInductionVar().getUsers()) {
      if (backwardSlice.contains(user)) {
        packingLoops.insert(forOp);
        break;
      }
    }
  }

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  assert(outermostEnclosingForOp == backwardSlice.front());

  // All packing-loop bounds/steps must be expressible above `outer`: lb and
  // step must be loop-independent, ub may additionally be a chain of
  // scf.for / affine.min / affine.apply (handled by
  // computeLoopIndependentUpperBound later).
  scf::ForOp outer = cast<scf::ForOp>(outermostEnclosingForOp);
  if (llvm::any_of(packingLoops, [&](Operation *op) {
        scf::ForOp forOp = cast<scf::ForOp>(op);
        Value lb = forOp.lowerBound(), ub = forOp.upperBound(),
              step = forOp.step();
        return !isDefinedOutsideOrConstant(outer, lb) ||
               !(isDefinedOutsideOrConstant(outer, ub) ||
                 backwardsSliceOnlyHasOpsOfType<scf::ForOp, AffineMinOp,
                                                AffineApplyOp>(outer, ub)) ||
               !isDefinedOutsideOrConstant(outer, step);
      }))
    return failure();

  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);
  // One dynamic dimension (the trip count) per packing loop.
  dynamicTensorSizes =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *op) {
        return buildLoopTripCount(b, cast<scf::ForOp>(outermostEnclosingForOp),
                                  cast<scf::ForOp>(op));
      }));
  // Assert all loop trip counts can be computed.
  if (!llvm::all_of(dynamicTensorSizes, [](Value v) { return v; }))
    llvm_unreachable("loop independence prerequisite not met");
  return success();
}

/// Hoist the given `padTensorOp` above its `nLoops` immediately enclosing
/// `scf.for` loops, amortizing the padding work across loop iterations.
/// The padded results of all iterations are precomputed into a single
/// `nLoops`-dimensions-larger "packed" tensor allocated above the outermost
/// loop; the original pad op is replaced by a rank-reducing SubTensorOp that
/// reads the per-iteration slice back out of the packed tensor.
/// On success, `padTensorOp` is updated in place to point at the hoisted
/// (cloned) pad op so the caller can keep operating on it.
/// Returns failure (leaving the IR unchanged) when the prerequisites
/// computed by `hoistPaddingOnTensorsPrerequisites` do not hold.
LogicalResult mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                                  unsigned nLoops) {
  SmallVector<Value> dynamicTensorSizes;
  SetVector<Operation *> backwardSlice, packingLoops;
  if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
                                                backwardSlice, packingLoops,
                                                dynamicTensorSizes)))
    return failure();

  // Update actual number of loops, which may be smaller.
  nLoops = packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  unsigned paddedRank = paddedTensorType.getRank();

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  Operation *outermostEnclosingForOp = backwardSlice.front();
  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding. Its leading `nLoops` dimensions are dynamic (one per packing
  // loop, sized by the loop trip count); the trailing dims are the padded
  // shape itself.
  SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicTensorSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create a SubTensorInsert at the top of the stack.
  //   3. Iteratively pop and yield the result of the SubTensorInsertOp across
  //      the cloned loops.
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nLoops);
  leadingPackedTensorIndexings.reserve(nLoops);
  BlockAndValueMapping bvm;
  // Insert `padTensorOp` into the backwardSlice so we clone it too.
  backwardSlice.insert(padTensorOp);
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  for (Operation *op : backwardSlice) {
    // Specifically sit out in the subtensor(packedTensor) case: this is the
    // piece we seek to replace.
    if (auto subTensor = dyn_cast<SubTensorOp>(op))
      if (bvm.lookupOrDefault(subTensor.source()) == packedTensor)
        continue;
    // Side-effect-free, region-less ops (and PadTensorOp itself, whose region
    // is self-contained) can be cloned wholesale via the mapping.
    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
    bool hasNoEffects = !effects || effects.hasNoEffect();
    if (hasNoEffects &&
        (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op))) {
      b.clone(*op, bvm);
      continue;
    }
    // TODO: support more cases as they appear.
    auto forOp = dyn_cast<scf::ForOp>(op);
    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
    // Unused loop, just skip it.
    if (!packingLoops.contains(forOp))
      continue;

    // Clone the loop, threading `packedTensor` through it as an iter_arg so
    // each iteration can accumulate its padded slice.
    auto clonedForOp =
        b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()),
                             bvm.lookupOrDefault(forOp.upperBound()),
                             bvm.lookupOrDefault(forOp.step()), packedTensor);
    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Subsequent clones go inside the loop body we just created.
    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount = buildLoopIterationCount(
        b, cast<scf::ForOp>(outermostEnclosingForOp), clonedForOp);
    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
    // Push: the current packed tensor is now the loop's region iter_arg.
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Stack step 2. create SubTensorInsertOp at the top of the stack.
  // offsets = [clonedLoopIvs, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape].
  SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
  for (int64_t sz : paddedTensorType.getShape()) {
    // TODO: go grab dims when necessary, for now PadTensorOp returns a static
    // tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));

  // Insert the cloned pad result into the packed tensor at this iteration's
  // leading indices.
  Value inserted =
      b.create<SubTensorInsertOp>(loc, bvm.lookup(padTensorOp.result()),
                                  packedTensor, offsets, sizes, strides);

  // Stack step 3. iteratively pop the stack and propagate the yield.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 SubTensor [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(padTensorOp);
  SmallVector<Value> loopIterationCounts =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(
            b, cast<scf::ForOp>(outermostEnclosingForOp),
            cast<scf::ForOp>(loop));
      }));
  // Assert all loop iteration counts can be computed.
  if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
    llvm_unreachable("loop independence prerequisite not met");
  // offsets = [originalLoopIvs, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  padTensorOp.replaceAllUsesWith(
      b.create<SubTensorOp>(loc, padTensorOp.getResultType(), packedTensor,
                            offsets, sizes, strides)
          ->getResult(0));

  // Keep a handle to the original op: erasing it must happen after the bvm
  // lookup below, which still keys off the original result value.
  Operation *toErase = padTensorOp;

  // Make the newly cloned `padTensorOp` available to the caller.
  padTensorOp =
      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());

  toErase->erase();

  return success();
}