1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements functions concerned with hoisting invariant operations 10 // in the context of Linalg transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 15 #include "mlir/Analysis/AffineStructures.h" 16 #include "mlir/Analysis/SliceAnalysis.h" 17 #include "mlir/Dialect/Affine/Utils.h" 18 #include "mlir/Dialect/Linalg/Analysis/ConstraintsSet.h" 19 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/SCF/SCF.h" 22 #include "mlir/Dialect/SCF/Utils.h" 23 #include "mlir/Dialect/StandardOps/IR/Ops.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/Dialect/Vector/VectorOps.h" 26 #include "mlir/Dialect/Vector/VectorUtils.h" 27 #include "mlir/IR/BuiltinOps.h" 28 #include "mlir/IR/Dominance.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "mlir/Transforms/LoopUtils.h" 31 #include "llvm/ADT/StringRef.h" 32 #include "llvm/Support/Debug.h" 33 34 using llvm::dbgs; 35 36 #define DEBUG_TYPE "linalg-hoisting" 37 38 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 39 40 using namespace mlir; 41 using namespace mlir::linalg; 42 43 namespace { 44 /// Represents a unit of hoistable TransferWriteOp. This may comprise other 45 /// instructions that need to be hoisted too. 46 struct HoistableWrite { 47 vector::TransferWriteOp transferWriteOp; 48 tensor::InsertSliceOp insertSliceOp; 49 }; 50 /// Represents a unit of hoistable TransferReadOp. 
This may comprise other 51 /// instructions that need to be hoisted too. 52 struct HoistableRead { 53 vector::TransferReadOp transferReadOp; 54 tensor::ExtractSliceOp extractSliceOp; 55 }; 56 } // namespace 57 58 /// Return true if op1 and op2 are the same constant or the same SSA value. 59 static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) { 60 auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> { 61 Attribute attr = ofr.dyn_cast<Attribute>(); 62 // Note: isa+cast-like pattern allows writing the condition below as 1 line. 63 if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>()) 64 attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue(); 65 if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>()) 66 return intAttr.getValue().getSExtValue(); 67 return llvm::None; 68 }; 69 auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2); 70 if (cst1 && cst2 && *cst1 == *cst2) 71 return true; 72 auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>(); 73 return v1 && v2 && v1 == v2; 74 } 75 76 /// Return true is all offsets, sizes and strides are equal. 
static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
                                       tensor::InsertSliceOp si) {
  // Ranks must agree before any pairwise comparison is meaningful.
  if (s.static_offsets().size() != si.static_offsets().size())
    return false;
  if (s.static_sizes().size() != si.static_sizes().size())
    return false;
  if (s.static_strides().size() != si.static_strides().size())
    return false;
  // Compare offsets, sizes and strides pairwise; each entry may be a static
  // attribute or an SSA value (see isEqualOffsetSizeOrStride).
  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  return true;
}

/// Look for a HoistableRead, in the given tensor uses, accessing the same
/// offset as the HoistableWrite.
/// Returns a null HoistableRead (default-constructed) when no matching read
/// is found among the users of `srcTensor`.
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.insertSliceOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: "
                      << *write.insertSliceOp.getOperation() << "\n");

  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If HoistableWrite involves a InsertSliceOp, we need to find a
    // matching ExtractSliceOp.
    tensor::ExtractSliceOp sliceOp;
    Operation *maybeTransferReadUser = user;
    if (write.insertSliceOp) {
      sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
      // The extracted slice must have the same type as the inserted source.
      if (!sliceOp || sliceOp.getResult().getType() !=
                          write.insertSliceOp.source().getType())
        continue;

      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *sliceOp << " vs " << *write.insertSliceOp << "\n");
      // The slice must access exactly the same region as the insert.
      if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
      //   1. the transfer_write we want to hoist.
      //   2. a matching transfer_read.
      // Anything else, we skip.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : sliceOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        if (otherUser) {
          // More than one "other" use: not hoistable.
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    // Accept only a transfer_read with identical indices and vector type.
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    if (read && read.indices() == write.transferWriteOp.indices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, sliceOp};
  }
  return HoistableRead();
}

/// Check if the chunk of data inserted by the HoistableWrite are read by any
/// other op than the HoistableRead candidate.
/// Returns true (conservatively "accessed by unknown op") whenever a use
/// cannot be proven disjoint from the written chunk.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  // Worklist of use-ranges to inspect; seeded with the direct uses of the
  // loop-carried tensor argument and grown as uses are followed transitively.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Consider all transitive uses through a extract_slice / insert_slice.
      // TODO: atm we just bail because a stronger analysis is needed for these
      // cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        // Map the operand to the corresponding region iter_arg (+1 skips the
        // induction variable block argument).
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        // Continue the traversal through the enclosing op's matching result.
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      // A transfer_read is acceptable only when provably disjoint from the
      // hoistable write; anything else is an unknown accessor.
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}

/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional insert_slice or if any of the indexings
/// is `forOp`-dependent.
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  // Case 1: the yielded value is produced directly by a transfer_write.
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  // Case 2: the yielded value is produced by insert_slice(transfer_write).
  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
    // Inserted slice must come from vector.transfer_write.
    auto write =
        insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching yieldOperand's.
    auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    // Only the offset/size/stride operands are checked; the source operand is
    // the transfer_write result and the dest is the loop BBArg (checked above).
    for (Value operand : insertSliceOp->getOperands().drop_front(
             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, insertSliceOp};
  }

  // Neither pattern matched: not hoistable.
  return HoistableWrite();
}

/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
/// Precondition: the pair has already been validated (matching slice ops,
/// loop-invariant indexings, no unknown accessors of the written chunk).
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp})))
    llvm_unreachable("Unexpected failure moving extract_slice out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor: the hoisted read now reads from the loop's init
  // argument instead of the loop-carried BBArg.
  if (read.extractSliceOp)
    read.extractSliceOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(forOp.initArgs()[initArgNumber]);

  // Hoist write after.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield: the loop now yields the tensor the write used to
  // consume, short-circuiting the hoisted ops.
  auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `loop` with additional new yields carrying the vector value.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // Transfer write has been hoisted, need to update the vector and tensor
  // source. Replace the result of the loop to use the new tensor created
  // outside the loop.
  // Depending on whether a insert_slice is present or not, it carries the
  // update on the tensor operands.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
    write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult(0));
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yield tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}

// To hoist transfer op on tensor the logic can be significantly simplified
// compared to the case on buffer. The transformation follows this logic:
// 1. Look for transfer_write with a single use from ForOp yield
// 2. Check the uses of the matching block argument and look for a transfer_read
//    with the same indices.
// 3.
//    Check that all the other uses of the tensor argument are either disjoint
//    tensor_read or transfer_write. For transfer_write uses recurse to make
//    sure the new tensor has the same restrictions on its uses.
// 4. Hoist the tensor_read/tensor_write and update the tensor SSA links.
// After this transformation the scf.forOp may have unused arguments that can be
// remove by the canonicalization pass.
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  // Fixed-point iteration: each successful hoist erases a loop, so the walk is
  // restarted and re-run until no further change is made.
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      // Inspect each loop-carried tensor and its matching yield operand.
      for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        // The write must exist and feed only the yield (single use).
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        // All write indices must be loop-invariant.
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor modified
        // by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}

void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
  // Fixed-point iteration, same restart-on-change scheme as the tensor
  // variant above.
  bool changed = true;
  while (changed) {
    changed = false;

    func.walk([&](vector::TransferReadOp transferRead) {
      // This entry point only handles transfers on buffers (memrefs).
      if (!transferRead.getShapedType().isa<MemRefType>())
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *transferRead.getOperation() << "\n");
      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
      LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                        << "\n");
      if (!loop)
        return WalkResult::advance();

      // First hoist any generic loop-invariant code out of the loop; this may
      // expose the candidate transfers as loop-invariant.
      if (failed(moveLoopInvariantCode(
              cast<LoopLikeOpInterface>(loop.getOperation()))))
        llvm_unreachable(
            "Unexpected failure to move invariant code out of loop");

      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
                        << "\n");

      SetVector<Operation *> forwardSlice;
      getForwardSlice(transferRead.getOperation(), &forwardSlice);

      // Look for the last TransferWriteOp in the forwardSlice of
      // `transferRead` that operates on the same memref.
      vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite || candidateWrite.source() != transferRead.source())
          continue;
        transferWrite = candidateWrite;
      }

      // All operands of the TransferRead must be defined outside of the loop.
432 for (auto operand : transferRead.getOperands()) 433 if (!loop.isDefinedOutsideOfLoop(operand)) 434 return WalkResult::advance(); 435 436 // Only hoist transfer_read / transfer_write pairs for now. 437 if (!transferWrite) 438 return WalkResult::advance(); 439 440 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 441 << "\n"); 442 443 // Approximate aliasing by checking that: 444 // 1. indices are the same, 445 // 2. no other operations in the loop access the same memref except 446 // for transfer_read/transfer_write accessing statically disjoint 447 // slices. 448 if (transferRead.indices() != transferWrite.indices() && 449 transferRead.getVectorType() == transferWrite.getVectorType()) 450 return WalkResult::advance(); 451 452 // TODO: may want to memoize this information for performance but it 453 // likely gets invalidated often. 454 DominanceInfo dom(loop); 455 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 456 return WalkResult::advance(); 457 for (auto &use : transferRead.source().getUses()) { 458 if (!dom.properlyDominates(loop, use.getOwner())) 459 continue; 460 if (use.getOwner() == transferRead.getOperation() || 461 use.getOwner() == transferWrite.getOperation()) 462 continue; 463 if (auto transferWriteUse = 464 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 465 if (!isDisjointTransferSet( 466 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 467 cast<VectorTransferOpInterface>( 468 transferWriteUse.getOperation()))) 469 return WalkResult::advance(); 470 } else if (auto transferReadUse = 471 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 472 if (!isDisjointTransferSet( 473 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 474 cast<VectorTransferOpInterface>( 475 transferReadUse.getOperation()))) 476 return WalkResult::advance(); 477 } else { 478 // Unknown use, we cannot prove that it doesn't alias with the 479 // transferRead/transferWrite operations. 
480 return WalkResult::advance(); 481 } 482 } 483 484 // Hoist read before. 485 if (failed(loop.moveOutOfLoop({transferRead}))) 486 llvm_unreachable( 487 "Unexpected failure to move transfer read out of loop"); 488 489 // Hoist write after. 490 transferWrite->moveAfter(loop); 491 492 // Rewrite `loop` with new yields by cloning and erase the original loop. 493 OpBuilder b(transferRead); 494 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 495 transferWrite.vector()); 496 497 // Transfer write has been hoisted, need to update the written value to 498 // the value yielded by the newForOp. 499 transferWrite.vector().replaceAllUsesWith( 500 newForOp.getResults().take_back()[0]); 501 502 changed = true; 503 loop.erase(); 504 // Need to interrupt and restart because erasing the loop messes up the 505 // walk. 506 return WalkResult::interrupt(); 507 }); 508 } 509 } 510 511 /// Return success if `v` is a value that is only transitively defined by ops of 512 /// type in `OpTypeList`. 513 template <typename... OpTypeList> 514 static bool backwardsSliceOnlyHasOpsOfType(scf::ForOp outerLimit, Value v) { 515 // Compute a backward slice up to, but not including, `outerLimit`. 516 SetVector<Operation *> backwardSlice; 517 getBackwardSlice(v, &backwardSlice, [&](Operation *op) { 518 return outerLimit->isProperAncestor(op); 519 }); 520 // Traverse the backward slice and ensure we can perform the computation to 521 // hoist. 522 for (Operation *op : backwardSlice) { 523 if (isa<OpTypeList...>(op)) 524 continue; 525 LLVM_DEBUG(DBGS() << "Abort: unadmissible op in slice " << *op << "\n"); 526 return false; 527 } 528 return true; 529 } 530 531 bool isDefinedOutsideOrConstant(scf::ForOp outer, Value v) { 532 return outer.isDefinedOutsideOfLoop(v) || v.getDefiningOp<ConstantOp>(); 533 } 534 535 /// Return the current iteration number in the loop (iv - lb).ceilDiv(step). 
/// The returned Value is guaranteed not to depend on any loop comprised in
/// [`outer`, `forOp`].
/// Return null if such a loop-independent quantity cannot be computed.
static Value buildLoopIterationCount(OpBuilder &b, scf::ForOp outer,
                                     scf::ForOp forOp) {
  MLIRContext *ctx = forOp->getContext();
  // Affine expression (iv - lb).ceilDiv(step) with iv, lb as dims and step as
  // a symbol.
  AffineExpr iv, lb, step;
  bindDims(ctx, iv, lb);
  bindSymbols(ctx, step);
  // lb and step must themselves be loop-independent for the count to be.
  if (!isDefinedOutsideOrConstant(outer, forOp.lowerBound()) ||
      !isDefinedOutsideOrConstant(outer, forOp.step()))
    return Value();
  Value ivVal = forOp.getInductionVar(), lbVal = forOp.lowerBound(),
        stepVal = forOp.step();
  auto loc = forOp->getLoc();
  return b.createOrFold<AffineApplyOp>(loc, (iv - lb).ceilDiv(step),
                                       ValueRange{ivVal, lbVal, stepVal});
}

/// Given a set of loops, assumed to be scf::ForOp, create a constraint set
/// containing the inequalities `iv - lb >= 0` and `-iv + ub - 1 >= 0` for each
/// loop.
static ConstraintsSet initLoopIvsAndBounds(ArrayRef<Operation *> loops) {
  ConstraintsSet constraints;
  // Dim ids are laid out as [ivs..., lbs..., ubs...], numLoops of each.
  for (Operation *op : loops)
    constraints.addDimId(constraints.getNumDimIds(),
                         cast<scf::ForOp>(op).getInductionVar());
  for (Operation *op : loops)
    constraints.addDimId(constraints.getNumDimIds(),
                         cast<scf::ForOp>(op).lowerBound());
  for (Operation *op : loops)
    constraints.addDimId(constraints.getNumDimIds(),
                         cast<scf::ForOp>(op).upperBound());
  unsigned numLoops = loops.size();
  for (unsigned ivIdx = 0, e = numLoops; ivIdx < e; ++ivIdx) {
    // iv - lb >= 0
    SmallVector<int64_t, 8> ineqLb(constraints.getNumCols(), 0);
    ineqLb[ivIdx] = 1;
    ineqLb[ivIdx + numLoops] = -1;
    // -iv + ub - 1 >= 0 (i.e. iv <= ub - 1; the -1 is the constant term set
    // on the last column below).
    SmallVector<int64_t, 8> ineqUb(constraints.getNumCols(), 0);
    ineqUb[ivIdx] = -1;
    ineqUb[ivIdx + 2 * numLoops] = 1;
    ineqUb[constraints.getNumCols() - 1] = -1;
    constraints.addInequality(ineqLb);
    constraints.addInequality(ineqUb);
  }
  return constraints;
}

/// For each loop in `loops`, determine the ops involved in the construction of
/// its upper bound---up to the outerLimit loop--- and fold them as new
/// inequalities in the constraint set.
/// This is achieved by computing the backwardSlice of the loop's upper bound
/// and iteratively folding each op in reverse topological order to guarantee
/// use-def ordering.
/// As operations are folded in, their result is projected out of the
/// constraints set.
/// The following operations are supported:
///   - scf::ForOp are simply skipped.
///   - AffineApplyOp are composed to replace the result by an equality.
///   - AffineMinOp are composed by adding each entry as an upper bound.
/// If any other operation is met, return failure.
// TODO: extend on a per-need basis.
static LogicalResult
foldUpperBoundsIntoConstraintsSet(ConstraintsSet &constraints,
                                  scf::ForOp outerLimit,
                                  ArrayRef<Operation *> loops) {
  SetVector<Value> toProjectOut;
  for (Operation *loop : loops) {
    auto ub = cast<scf::ForOp>(loop).upperBound();
    // Upper bounds already invariant w.r.t. `outerLimit` need no folding.
    if (isDefinedOutsideOrConstant(outerLimit, ub))
      continue;

    // Compute a backward slice up to, but not including, `outerLimit`.
    SetVector<Operation *> backwardSlice;
    getBackwardSlice(ub, &backwardSlice, [&](Operation *op) {
      return outerLimit->isProperAncestor(op);
    });
    backwardSlice.insert(ub.getDefiningOp());

    // Iterate over all ops in the slice and compose them in the constraints.
    for (Operation *op : llvm::reverse(backwardSlice)) {
      if (!isa<scf::ForOp, AffineApplyOp, AffineMinOp>(op))
        return failure();
      if (isa<scf::ForOp>(op))
        continue;
      // Ensure there is an id in the constraint set for `v`; the helper
      // returns true when that fails. (NOTE(review): the original comment
      // was truncated here.)
      auto ensureIdFailed = [&](Value v) {
        return failed(constraints.ensureIdOfType(v, /*asDim=*/true));
      };

      // Ensure all ids exist and add results for later projection.
      if (llvm::any_of(op->getResults(), ensureIdFailed) ||
          llvm::any_of(op->getOperands(), ensureIdFailed))
        return failure();

      // All supported ops have 1 result.
      // TODO: extend when needed.
      toProjectOut.insert(op->getResult(0));

      // Compose supported ops.
      if (auto affineApplyOp = dyn_cast<AffineApplyOp>(op)) {
        if (failed(constraints.composeAffineApply(affineApplyOp.getResult(),
                                                  affineApplyOp.getAffineMap(),
                                                  affineApplyOp.getOperands())))
          return failure();
        continue;
      }
      // Only AffineMinOp remains per the isa filter above.
      auto affineMinOp = cast<AffineMinOp>(op);
      if (failed(constraints.composeMin(affineMinOp.getResult(),
                                        affineMinOp.getAffineMap(),
                                        affineMinOp.operands())))
        return failure();
    }
  }
  // Intermediate results have served their purpose; eliminate them.
  for (Value v : toProjectOut)
    constraints.projectOut(v);
  return success();
}

/// Compute dynamic tensor sizes, independent of any value defined inside
/// `outer` and such that every n-D iteration of the packingLoops has its own
/// space (so that each packed buffer has a storage location). This is achieved
/// by computing the extent for each of the packing loops.
static LogicalResult computeBounds(scf::ForOp outer,
                                   ArrayRef<Operation *> packingLoops,
                                   SmallVector<AffineMap> &lbs,
                                   SmallVector<AffineMap> &ubs) {
  // Packing loop IVs are introduced as the first positions.
  ConstraintsSet constraints = initLoopIvsAndBounds(packingLoops);
  if (failed(
          foldUpperBoundsIntoConstraintsSet(constraints, outer, packingLoops)))
    return failure();
  // Compute the bounds of the first positions, assuming the others are fixed.
  constraints.getSliceBounds(/*pos=*/0, /*num=*/packingLoops.size(),
                             outer->getContext(), &lbs, &ubs);
  return success();
}

/// Ensure prerequisites that guarantee pad op hoisting can occur.
/// Return failure in the cases when we cannot perform hoisting; i.e. if either:
///   1. There exists a use of `padTensorOp` that is not a linalg input operand.
///   2. There isn't an enclosing `outermostEnclosingForOp` loop.
///   3. There exists an op with a region that is dominated by
///      `outermostEnclosingForOp` and that isn't a LoopLikeInterface or a
///      LinalgOp.
///   4. There exists an op with side effects that is dominated by
///      `outermostEnclosingForOp` and that isn't a LoopLikeInterface.
///   5. The lower bound, upper bound and step of all the loops involved in the
///      hoisting can be ... (NOTE(review): sentence truncated in the original;
///      presumably "computed above the outermost enclosing loop" -- confirm.)
///
/// While ensuring prerequisites:
///   1. Fill the `backwardSlice` to contain the topologically sorted ops
///      dominated by `outermostEnclosingForOp`.
///   2. Fill the `packingLoops` to contain only the enclosing loops of
///      `backwardSlice` whose IV is actually used in computing padding. Loops
///      that remain in `backwardSlice` but that are not in `packingLoops` are
///      dimensions of reuse.
static LogicalResult
hoistPaddingOnTensorsPrerequisites(linalg::PadTensorOp padTensorOp, int nLevels,
                                   SetVector<Operation *> &backwardSlice,
                                   SetVector<Operation *> &packingLoops,
                                   SmallVector<Value> &dynamicTensorSizes) {
  // Bail on any use that isn't an input of a Linalg op.
  // Hoisting of inplace updates happens after vectorization.
  for (OpOperand &use : padTensorOp.result().getUses()) {
    auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
    if (!linalgUser || !linalgUser.isInputTensor(&use))
      return failure();
  }

  // Get at most nLevels of enclosing loops.
  // Walk up at most `nLevels` enclosing loop-like ops; the loops are gathered
  // innermost-first (hence "reverse").
  SmallVector<LoopLikeOpInterface> reverseEnclosingLoops;
  Operation *outermostEnclosingForOp = nullptr,
            *nextEnclosingForOp =
                padTensorOp->getParentOfType<LoopLikeOpInterface>();
  while (nLevels-- > 0 && nextEnclosingForOp) {
    outermostEnclosingForOp = nextEnclosingForOp;
    reverseEnclosingLoops.push_back(outermostEnclosingForOp);
    nextEnclosingForOp =
        nextEnclosingForOp->getParentOfType<LoopLikeOpInterface>();
  }
  // Prerequisite 2: there must be at least one enclosing loop.
  if (!outermostEnclosingForOp)
    return failure();

  // Get the backwards slice from `padTensorOp` that is dominated by the
  // outermost enclosing loop.
  DominanceInfo domInfo(outermostEnclosingForOp);
  getBackwardSlice(padTensorOp.getOperation(), &backwardSlice,
                   [&](Operation *op) {
                     return domInfo.dominates(outermostEnclosingForOp, op);
                   });

  // Bail on any op with a region that is not a LoopLikeInterface or a LinalgOp.
  if (llvm::any_of(backwardSlice, [](Operation *op) {
        return op->getNumRegions() > 0 && !isa<LoopLikeOpInterface>(op) &&
               !isa<LinalgOp>(op);
      }))
    return failure();

  // Filter out the loops whose induction variable is not used to compute the
  // padded result. As a first approximation, just look for IVs that have no use
  // in the backwardSlice.
  // These are the dimensions of reuse that we can exploit to reduce the amount
  // of work / memory.
  // TODO: would this optimization compose better as a canonicalization?
  for (LoopLikeOpInterface loop : llvm::reverse(reverseEnclosingLoops)) {
    auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
    if (!forOp)
      continue;
    for (Operation *user : forOp.getInductionVar().getUsers()) {
      if (backwardSlice.contains(user)) {
        packingLoops.insert(forOp);
        break;
      }
    }
  }

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  assert(outermostEnclosingForOp == backwardSlice.front());

  scf::ForOp outer = cast<scf::ForOp>(outermostEnclosingForOp);

  // Build the constraint system over the packing loops and fold their upper
  // bounds in; this validates that the bounds are expressible (prereq 5).
  ConstraintsSet constraints = initLoopIvsAndBounds(packingLoops.getArrayRef());
  if (failed(foldUpperBoundsIntoConstraintsSet(constraints, outer,
                                               packingLoops.getArrayRef())))
    return failure();

  unsigned numLoops = packingLoops.size();
  SmallVector<AffineMap> lbs(numLoops), ubs(numLoops);
  if (failed(computeBounds(outer, packingLoops.getArrayRef(), lbs, ubs)))
    return failure();

  // The bound maps are expressed over the constraint values minus the leading
  // loop IVs; gather those remaining values as map operands.
  SmallVector<Value> allValues;
  constraints.getAllValues(&allValues);
  SmallVector<Value> allNonLoopValues(allValues.begin() + numLoops,
                                      allValues.end());

  // For each packingLoop, create the extent by (ub - lb).ceilDiv(step).
  // IP just before the outermost loop considered that we hoist above.
  ImplicitLocOpBuilder b(outer->getLoc(), outer);
  assert(packingLoops.size() == lbs.size() && "expected matching lb sizes");
  assert(packingLoops.size() == ubs.size() && "expected matching ub sizes");
  for (auto it : llvm::zip(packingLoops, lbs, ubs)) {
    scf::ForOp loop = cast<scf::ForOp>(std::get<0>(it));
    AffineMap lbMap = std::get<1>(it);
    AffineMap ubMap = std::get<2>(it);
    SmallVector<Value> lbOperands(allNonLoopValues);
    canonicalizeMapAndOperands(&lbMap, &lbOperands);
    // Materialize the tightest lower bound as a max over the lb map results.
    Value lbVal = b.createOrFold<AffineMaxOp>(lbMap, lbOperands);

    SmallVector<Value> ubOperands(allNonLoopValues);
    canonicalizeMapAndOperands(&ubMap, &ubOperands);
    // Materialize the tightest upper bound as a min over the ub map results.
    Value ubVal = b.createOrFold<AffineMinOp>(ubMap, ubOperands);

    // Extent = (ub - lb).ceilDiv(step).
    AffineExpr lb, ub, step;
    bindDims(b.getContext(), lb, ub);
    bindSymbols(b.getContext(), step);
    Value res = b.createOrFold<AffineApplyOp>(
        (ub - lb).ceilDiv(step),
        ValueRange{lbVal, ubVal, cast<scf::ForOp>(loop).step()});

    dynamicTensorSizes.push_back(res);
  }

  return success();
}

LogicalResult
mlir::linalg::hoistPaddingOnTensors(PadTensorOp &padTensorOp,
                                    unsigned nLoops) {
  // Collect the prerequisites: the dynamic sizes of the packed tensor, the
  // backward slice of ops needed to recompute the padding, and the enclosing
  // loops whose IVs actually feed the padded value ("packing loops").
  SmallVector<Value> dynamicTensorSizes;
  SetVector<Operation *> backwardSlice, packingLoops;
  if (failed(hoistPaddingOnTensorsPrerequisites(padTensorOp, nLoops,
                                                backwardSlice, packingLoops,
                                                dynamicTensorSizes)))
    return failure();

  // Update actual number of loops, which may be smaller than the requested
  // `nLoops`: only the loops in `packingLoops` contribute packing dimensions.
  nLoops = packingLoops.size();

  Location loc = padTensorOp->getLoc();
  RankedTensorType paddedTensorType = padTensorOp.getResultType();
  unsigned paddedRank = paddedTensorType.getRank();

  // Backward slice is a topologically sorted list of ops starting at
  // `outermostEnclosingForOp`.
  Operation *outermostEnclosingForOp = backwardSlice.front();
  // IP just before the outermost loop considered that we hoist above.
  OpBuilder b(outermostEnclosingForOp);

  // Create the packed tensor<?x?x..?xpadded_shape> into which we amortize
  // padding: one leading dynamic dimension per packing loop, followed by the
  // (static) padded shape.
  SmallVector<int64_t> packedShape(nLoops, ShapedType::kDynamicSize);
  // TODO: go grab dims when necessary, for now PadTensorOp returns a static
  // tensor.
  llvm::append_range(packedShape, paddedTensorType.getShape());
  auto packedTensorType =
      RankedTensorType::get(packedShape, paddedTensorType.getElementType());
  Value packedTensor = b.create<linalg::InitTensorOp>(
      loc, dynamicTensorSizes, packedTensorType.getShape(),
      packedTensorType.getElementType());

  // Clone the operations involved in the backward slice, iteratively stepping
  // into the loops that we encounter.
  // The implementation proceeds in a stack-like fashion:
  //   1. Iteratively clone and step into the loops, pushing the `packedTensor`
  //      deeper in the stack.
  //   2. Create a InsertSliceOp at the top of the stack.
  //   3. Iteratively pop and yield the result of the InsertSliceOp across
  //      the cloned loops.
  SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings;
  clonedLoopIvs.reserve(nLoops);
  leadingPackedTensorIndexings.reserve(nLoops);
  BlockAndValueMapping bvm;
  // Insert `padTensorOp` into the backwardSlice so we clone it too.
  backwardSlice.insert(padTensorOp);
  // Stack step 1. iteratively clone loops and push `packedTensor`.
  for (Operation *op : backwardSlice) {
    // Specifically sit out in the extract_slice(packedTensor) case: this is the
    // piece we seek to replace.
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
      if (bvm.lookupOrDefault(sliceOp.source()) == packedTensor)
        continue;
    // Side-effect-free, region-less ops (plus PadTensorOp itself, whose region
    // is just the padding-value yield) can be cloned verbatim through `bvm`.
    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
    bool hasNoEffects = !effects || effects.hasNoEffect();
    if (hasNoEffects &&
        (op->getNumRegions() == 0 || isa<linalg::PadTensorOp>(op))) {
      b.clone(*op, bvm);
      continue;
    }
    // TODO: support more cases as they appear.
    auto forOp = dyn_cast<scf::ForOp>(op);
    assert(forOp && "Expected scf::ForOp when hoisting pad ops");
    // Unused loop, just skip it.
    if (!packingLoops.contains(forOp))
      continue;

    // Clone the loop, threading `packedTensor` through as an iter_arg so the
    // packing accumulates across iterations.
    auto clonedForOp =
        b.create<scf::ForOp>(loc, bvm.lookupOrDefault(forOp.lowerBound()),
                             bvm.lookupOrDefault(forOp.upperBound()),
                             bvm.lookupOrDefault(forOp.step()), packedTensor);
    // Map the induction var, region args and results to the `clonedForOp`.
    bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
    bvm.map(forOp.getRegionIterArgs(), clonedForOp.getRegionIterArgs());
    bvm.map(forOp.getResults(), clonedForOp.getResults());
    assert(clonedForOp->getNumRegions() == 1);
    clonedLoopIvs.push_back(clonedForOp.getInductionVar());

    // Continue building inside the cloned loop body; compute this loop's
    // iteration count ((iv - lb).ceilDiv(step)) to use as a leading index into
    // the packed tensor.
    b.setInsertionPointToStart(&clonedForOp->getRegion(0).front());
    Value loopIndependentIterationCount = buildLoopIterationCount(
        b, cast<scf::ForOp>(outermostEnclosingForOp), clonedForOp);
    // Assert the loop-independent iteration count can be computed.
    if (!loopIndependentIterationCount)
      llvm_unreachable("loop independence prerequisite not met");
    leadingPackedTensorIndexings.push_back(loopIndependentIterationCount);
    // Push down the stack: within this loop, the packed tensor is the
    // iter_arg, not the outer InitTensorOp result.
    packedTensor = clonedForOp.getRegionIterArgs().front();
  }

  // Stack step 2. create InsertSliceOp at the top of the stack.
  // offsets = [clonedLoopIvs, 0 .. 0].
  SmallVector<OpFoldResult> offsets(leadingPackedTensorIndexings.begin(),
                                    leadingPackedTensorIndexings.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape].
  SmallVector<OpFoldResult> sizes(nLoops, b.getIndexAttr(1));
  for (int64_t sz : paddedTensorType.getShape()) {
    // TODO: go grab dims when necessary, for now PadTensorOp returns a static
    // tensor.
    assert(!ShapedType::isDynamic(sz) && "padded tensor needs static sizes");
    sizes.push_back(b.getIndexAttr(sz));
  }
  // strides = [1 .. 1].
  SmallVector<OpFoldResult> strides(nLoops + paddedRank, b.getIndexAttr(1));

  // Insert the (cloned) padded result into the packed tensor at the slot
  // addressed by the current loop iteration counts.
  Value inserted =
      b.create<tensor::InsertSliceOp>(loc, bvm.lookup(padTensorOp.result()),
                                      packedTensor, offsets, sizes, strides);

  // Stack step 3. iteratively pop the stack and propagate the yield: each
  // cloned loop yields the updated packed tensor to its parent.
  Value valueToYield = inserted;
  for (Value iv : llvm::reverse(clonedLoopIvs)) {
    auto forOp = scf::getForInductionVarOwner(iv);
    b.setInsertionPointToEnd(&forOp.getRegion().front());
    b.create<scf::YieldOp>(loc, valueToYield);
    valueToYield = forOp.getResult(0);
  }

  // Now the packed tensor is ready, replace the original padding op by a
  // 1x..x1 slice [originalLoopIvs, 0 .. 0][1 .. 1, paddedShape][1 .. 1].
  b.setInsertionPoint(padTensorOp);
  SmallVector<Value> loopIterationCounts =
      llvm::to_vector<4>(llvm::map_range(packingLoops, [&](Operation *loop) {
        return buildLoopIterationCount(
            b, cast<scf::ForOp>(outermostEnclosingForOp),
            cast<scf::ForOp>(loop));
      }));
  // Assert all loop iteration counts can be computed.
  if (llvm::any_of(loopIterationCounts, [](Value v) { return !v; }))
    llvm_unreachable("loop independence prerequisite not met");
  // offsets = [originalLoopIvs, 0 .. 0].
  offsets.assign(loopIterationCounts.begin(), loopIterationCounts.end());
  offsets.append(paddedRank, b.getIndexAttr(0));
  // sizes = [1 .. 1, paddedShape] (defined above).
  // strides = [1 .. 1] (defined above).
  // The final packed tensor is the result of the outermost cloned loop.
  packedTensor =
      scf::getForInductionVarOwner(clonedLoopIvs.front())->getResult(0);
  padTensorOp.replaceAllUsesWith(
      b.create<tensor::ExtractSliceOp>(loc, padTensorOp.getResultType(),
                                       packedTensor, offsets, sizes, strides)
          ->getResult(0));

  // Capture the original op before reassigning `padTensorOp` below; erasing
  // it is deferred until after the bvm lookup that still uses its result.
  Operation *toErase = padTensorOp;

  // Make the newly cloned `padTensorOp` available to the caller.
  padTensorOp =
      cast<PadTensorOp>(bvm.lookup(padTensorOp.result()).getDefiningOp());

  toErase->erase();

  return success();
}