1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements functions concerned with hoisting invariant operations 10 // in the context of Linalg transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" 17 #include "mlir/Dialect/Affine/IR/AffineValueMap.h" 18 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 19 #include "mlir/Dialect/Func/IR/FuncOps.h" 20 #include "mlir/Dialect/Linalg/IR/Linalg.h" 21 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 22 #include "mlir/Dialect/SCF/IR/SCF.h" 23 #include "mlir/Dialect/SCF/Utils/Utils.h" 24 #include "mlir/Dialect/Tensor/IR/Tensor.h" 25 #include "mlir/Dialect/Vector/IR/VectorOps.h" 26 #include "mlir/Dialect/Vector/Utils/VectorUtils.h" 27 #include "mlir/IR/BuiltinOps.h" 28 #include "mlir/IR/Dominance.h" 29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 30 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" 31 #include "llvm/ADT/StringRef.h" 32 #include "llvm/Support/Debug.h" 33 34 using llvm::dbgs; 35 36 #define DEBUG_TYPE "linalg-hoisting" 37 38 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 39 40 using namespace mlir; 41 using namespace mlir::linalg; 42 43 namespace { 44 /// Represents a unit of hoistable TransferWriteOp. This may comprise other 45 /// instructions that need to be hoisted too. 
struct HoistableWrite {
  // The transfer_write to hoist; always set for a valid HoistableWrite.
  vector::TransferWriteOp transferWriteOp;
  // Optional insert_slice consuming the written tensor; null when the
  // transfer_write yields directly to the loop.
  tensor::InsertSliceOp insertSliceOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableRead {
  // The transfer_read to hoist; always set for a valid HoistableRead.
  vector::TransferReadOp transferReadOp;
  // Optional extract_slice feeding the read; null when the read sources the
  // loop iter_arg directly. Mirrors HoistableWrite::insertSliceOp.
  tensor::ExtractSliceOp extractSliceOp;
};
} // namespace

/// Return true if op1 and op2 are the same constant or the same SSA value.
static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
  // Resolve an OpFoldResult to a constant integer when it is either an
  // IntegerAttr or a Value produced by arith.constant; llvm::None otherwise.
  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
    Attribute attr = ofr.dyn_cast<Attribute>();
    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
    if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
      attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
      return intAttr.getValue().getSExtValue();
    return llvm::None;
  };
  // Equal if both fold to the same constant integer...
  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
  if (cst1 && cst2 && *cst1 == *cst2)
    return true;
  // ...or if both are the very same SSA value.
  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
  return v1 && v2 && v1 == v2;
}

/// Return true if all offsets, sizes and strides are equal.
77 static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s, 78 tensor::InsertSliceOp si) { 79 if (s.getStaticOffsets().size() != si.getStaticOffsets().size()) 80 return false; 81 if (s.getStaticSizes().size() != si.getStaticSizes().size()) 82 return false; 83 if (s.getStaticStrides().size() != si.getStaticStrides().size()) 84 return false; 85 for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets())) 86 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 87 return false; 88 for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes())) 89 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 90 return false; 91 for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides())) 92 if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it))) 93 return false; 94 return true; 95 } 96 97 /// Look for a HoistableRead, in the given tensor uses, accessing the same 98 /// offset as the HoistableWrite. 99 static HoistableRead findMatchingTransferRead(HoistableWrite write, 100 Value srcTensor) { 101 assert(write.transferWriteOp && 102 "expected hoistable write to have a .transfer_write"); 103 104 LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: " 105 << *write.transferWriteOp.getOperation() << "\n"); 106 if (write.insertSliceOp) 107 LLVM_DEBUG(DBGS() << "findMatchingTransferRead inserSliceOp: " 108 << *write.insertSliceOp.getOperation() << "\n"); 109 110 for (Operation *user : srcTensor.getUsers()) { 111 LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user 112 << "\n"); 113 114 // If HoistableWrite involves a InsertSliceOp, we need to find a 115 // matching ExtractSliceOp. 
116 tensor::ExtractSliceOp sliceOp; 117 Operation *maybeTransferReadUser = user; 118 if (write.insertSliceOp) { 119 sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); 120 if (!sliceOp || sliceOp.getResult().getType() != 121 write.insertSliceOp.getSource().getType()) 122 continue; 123 124 LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: " 125 << *sliceOp << " vs " << *write.insertSliceOp << "\n"); 126 if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp)) 127 continue; 128 129 LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n"); 130 // If we got here, sliceOp is hoistable iff it has exactly 2 uses: 131 // 1. the transfer_write we want to hoist. 132 // 2. a matching transfer_read. 133 // Anything else, we skip. 134 bool skip = false; 135 Operation *otherUser = nullptr; 136 for (Operation *u : sliceOp->getUsers()) { 137 if (u == write.transferWriteOp) 138 continue; 139 if (otherUser) { 140 skip = true; 141 break; 142 } 143 otherUser = u; 144 } 145 if (skip || !otherUser) 146 continue; 147 maybeTransferReadUser = otherUser; 148 } 149 150 LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser 151 << "\n"); 152 auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser); 153 if (read && read.getIndices() == write.transferWriteOp.getIndices() && 154 read.getVectorType() == write.transferWriteOp.getVectorType()) 155 return HoistableRead{read, sliceOp}; 156 } 157 return HoistableRead(); 158 } 159 160 /// Check if the chunk of data inserted by the HoistableWrite are read by any 161 /// other op than the HoistableRead candidate. 162 static bool tensorChunkAccessedByUnknownOp(HoistableWrite write, 163 HoistableRead candidateRead, 164 BlockArgument tensorArg) { 165 // Make sure none of the other uses read the part of the tensor modified 166 // by the transfer_write. 
167 llvm::SmallVector<Value::use_range, 1> uses; 168 uses.push_back(tensorArg.getUses()); 169 while (!uses.empty()) { 170 for (OpOperand &use : uses.pop_back_val()) { 171 Operation *user = use.getOwner(); 172 // Skip the candidate use, only inspect the "other" uses. 173 if (user == candidateRead.transferReadOp || 174 user == candidateRead.extractSliceOp || 175 user == write.transferWriteOp || user == write.insertSliceOp) 176 continue; 177 // Consider all transitive uses through a extract_slice / insert_slice. 178 // TODO: atm we just bail because a stronger analysis is needed for these 179 // cases. 180 if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user)) 181 return true; 182 // Consider all transitive uses through a vector.transfer_write. 183 if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) { 184 uses.push_back(writeUser->getResult(0).getUses()); 185 continue; 186 } 187 // Consider all nested uses through an scf::ForOp. We may have 188 // pass-through tensor arguments left from previous level of 189 // hoisting. 190 if (auto forUser = dyn_cast<scf::ForOp>(user)) { 191 Value arg = forUser.getLoopBody().getArgument( 192 use.getOperandNumber() - forUser.getNumControlOperands() + 193 /*iv value*/ 1); 194 uses.push_back(arg.getUses()); 195 continue; 196 } 197 // Follow the use yield as long as it doesn't escape the original 198 // region. 
199 scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user); 200 if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor( 201 yieldUser->getParentOp())) { 202 Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber()); 203 uses.push_back(ret.getUses()); 204 continue; 205 } 206 auto read = dyn_cast<vector::TransferReadOp>(user); 207 if (!read || !vector::isDisjointTransferIndices( 208 cast<VectorTransferOpInterface>(read.getOperation()), 209 cast<VectorTransferOpInterface>( 210 write.transferWriteOp.getOperation()))) { 211 return true; 212 } 213 } 214 } 215 return false; 216 } 217 218 /// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`. 219 /// Return the null HoistableWrite() if it is not comprised of a 220 /// vector.transfer_write + optional insert_slice or if any of the indexings 221 /// is `forOp`-dependent. 222 static HoistableWrite 223 getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp, 224 OpOperand &yieldOperand) { 225 Value v = yieldOperand.get(); 226 if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) { 227 // Indexing must not depend on `forOp`. 228 for (Value operand : write.getIndices()) 229 if (!forOp.isDefinedOutsideOfLoop(operand)) 230 return HoistableWrite(); 231 232 return HoistableWrite{write, nullptr}; 233 } 234 235 if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) { 236 // Inserted slice must come from vector.transfer_write. 237 auto write = 238 insertSliceOp.getSource().getDefiningOp<vector::TransferWriteOp>(); 239 if (!write) 240 return HoistableWrite(); 241 242 // Tensor inserted into must be a BBArg at position matching yieldOperand's. 243 auto bbArg = insertSliceOp.getDest().dyn_cast<BlockArgument>(); 244 if (!bbArg || bbArg.getOwner()->getParentOp() != forOp || 245 bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber()) 246 return HoistableWrite(); 247 248 // Indexing inserted into must not depend on `forOp`. 
249 for (Value operand : insertSliceOp->getOperands().drop_front( 250 tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex())) 251 if (!forOp.isDefinedOutsideOfLoop(operand)) 252 return HoistableWrite(); 253 254 return HoistableWrite{write, insertSliceOp}; 255 } 256 257 return HoistableWrite(); 258 } 259 260 /// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair. 261 static void hoistReadWrite(HoistableRead read, HoistableWrite write, 262 BlockArgument tensorBBArg) { 263 scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp()); 264 assert(read.transferReadOp && write.transferWriteOp && 265 "expected transfer_read and transfer_write ops to be set"); 266 assert(((read.extractSliceOp && write.insertSliceOp) || 267 (!read.extractSliceOp && !write.insertSliceOp)) && 268 "expected matching extract_slice / insert_slice"); 269 LLVM_DEBUG(DBGS() << "In forOp:\n" 270 << *forOp.getOperation() 271 << "\nHoist: " << *read.transferReadOp.getOperation() 272 << "\nHoist: " << *write.transferWriteOp.getOperation() 273 << "\nInvolving: " << tensorBBArg << "\n"); 274 275 // If a read slice is present, hoist it. 276 if (read.extractSliceOp) 277 forOp.moveOutOfLoop(read.extractSliceOp); 278 279 // Hoist the transfer_read op. 280 forOp.moveOutOfLoop(read.transferReadOp); 281 282 // TODO: don't hardcode /*numIvs=*/1. 283 assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); 284 unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1; 285 286 // Update the source tensor. 287 if (read.extractSliceOp) 288 read.extractSliceOp.getSourceMutable().assign( 289 forOp.getInitArgs()[initArgNumber]); 290 else 291 read.transferReadOp.getSourceMutable().assign( 292 forOp.getInitArgs()[initArgNumber]); 293 294 // Hoist write after. 295 if (write.insertSliceOp) 296 write.insertSliceOp->moveAfter(forOp); 297 write.transferWriteOp->moveAfter(forOp); 298 299 // Update the yield. 
300 auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator()); 301 if (write.insertSliceOp) 302 yieldOp->setOperand(initArgNumber, write.insertSliceOp.getDest()); 303 else 304 yieldOp->setOperand(initArgNumber, write.transferWriteOp.getSource()); 305 306 // Rewrite `loop` with additional new yields. 307 OpBuilder b(read.transferReadOp); 308 NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, 309 ArrayRef<BlockArgument> newBBArgs) { 310 return SmallVector<Value>{write.transferWriteOp.getVector()}; 311 }; 312 auto newForOp = replaceLoopWithNewYields( 313 b, forOp, read.transferReadOp.getVector(), yieldFn); 314 315 // Transfer write has been hoisted, need to update the vector and tensor 316 // source. Replace the result of the loop to use the new tensor created 317 // outside the loop. 318 // Depending on whether a insert_slice is present or not, it carries the 319 // update on the tensor operands. 320 if (write.insertSliceOp) { 321 newForOp.getResult(initArgNumber) 322 .replaceAllUsesWith(write.insertSliceOp.getResult()); 323 write.transferWriteOp.getSourceMutable().assign( 324 read.extractSliceOp.getResult()); 325 write.insertSliceOp.getDestMutable().assign( 326 read.extractSliceOp.getSource()); 327 } else { 328 newForOp.getResult(initArgNumber) 329 .replaceAllUsesWith(write.transferWriteOp.getResult()); 330 write.transferWriteOp.getSourceMutable().assign( 331 newForOp.getResult(initArgNumber)); 332 } 333 334 // Always update with the newly yield tensor and vector. 335 write.transferWriteOp.getVectorMutable().assign(newForOp.getResults().back()); 336 } 337 338 // To hoist transfer op on tensor the logic can be significantly simplified 339 // compared to the case on buffer. The transformation follows this logic: 340 // 1. Look for transfer_write with a single use from ForOp yield 341 // 2. Check the uses of the matching block argument and look for a transfer_read 342 // with the same indices. 343 // 3. 
Check that all the other uses of the tensor argument are either disjoint 344 // tensor_read or transfer_write. For transfer_write uses recurse to make sure 345 // the new tensor has the same restrictions on its uses. 346 // 4. Hoist the tensor_read/tensor_write and update the tensor SSA links. 347 // After this transformation the scf.forOp may have unused arguments that can be 348 // remove by the canonicalization pass. 349 void mlir::linalg::hoistRedundantVectorTransfersOnTensor(func::FuncOp func) { 350 bool changed = true; 351 while (changed) { 352 changed = false; 353 func.walk([&](scf::ForOp forOp) { 354 Operation *yield = forOp.getBody()->getTerminator(); 355 for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) { 356 OpOperand &ret = yield->getOpOperand(it.index()); 357 HoistableWrite write = 358 getLoopInvariantTransferWriteOpDefining(forOp, ret); 359 if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse()) 360 continue; 361 LLVM_DEBUG(dbgs() << "\n"; 362 DBGS() << "Candidate write for hoisting: " 363 << *write.transferWriteOp.getOperation() << "\n"); 364 if (write.insertSliceOp) 365 LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: " 366 << *write.insertSliceOp.getOperation() << "\n"); 367 if (llvm::any_of(write.transferWriteOp.getIndices(), 368 [&forOp](Value index) { 369 return !forOp.isDefinedOutsideOfLoop(index); 370 })) 371 continue; 372 // Find a read with the same type and indices. 373 HoistableRead matchingRead = 374 findMatchingTransferRead(write, it.value()); 375 // Make sure none of the other uses read the part of the tensor modified 376 // by the transfer_write. 377 if (!matchingRead.transferReadOp || 378 tensorChunkAccessedByUnknownOp(write, matchingRead, it.value())) 379 continue; 380 381 LLVM_DEBUG(DBGS() << "Start hoisting\n"); 382 hoistReadWrite(matchingRead, write, it.value()); 383 changed = true; 384 forOp.erase(); 385 386 // Need to interrupt and restart: erasing the loop messes up the walk. 
387 return WalkResult::interrupt(); 388 } 389 return WalkResult::advance(); 390 }); 391 // Apply canonicalization so the newForOp + yield folds immediately, thus 392 // cleaning up the IR and potentially enabling more hoisting. 393 if (changed) { 394 RewritePatternSet patterns(func->getContext()); 395 scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext()); 396 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 397 } 398 } 399 } 400 401 void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { 402 bool changed = true; 403 while (changed) { 404 changed = false; 405 // First move loop invariant ops outside of their loop. This needs to be 406 // done before as we cannot move ops without interrupting the function walk. 407 func.walk( 408 [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); 409 410 func.walk([&](vector::TransferReadOp transferRead) { 411 if (!transferRead.getShapedType().isa<MemRefType>()) 412 return WalkResult::advance(); 413 414 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 415 << *transferRead.getOperation() << "\n"); 416 auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp()); 417 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() 418 << "\n"); 419 if (!loop) 420 return WalkResult::advance(); 421 422 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 423 << "\n"); 424 425 SetVector<Operation *> forwardSlice; 426 getForwardSlice(transferRead.getOperation(), &forwardSlice); 427 428 // Look for the last TransferWriteOp in the forwardSlice of 429 // `transferRead` that operates on the same memref. 
430 vector::TransferWriteOp transferWrite; 431 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 432 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 433 if (!candidateWrite || 434 candidateWrite.getSource() != transferRead.getSource()) 435 continue; 436 transferWrite = candidateWrite; 437 } 438 439 // All operands of the TransferRead must be defined outside of the loop. 440 for (auto operand : transferRead.getOperands()) 441 if (!loop.isDefinedOutsideOfLoop(operand)) 442 return WalkResult::advance(); 443 444 // Only hoist transfer_read / transfer_write pairs for now. 445 if (!transferWrite) 446 return WalkResult::advance(); 447 448 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 449 << "\n"); 450 451 // Approximate aliasing by checking that: 452 // 1. indices are the same, 453 // 2. no other operations in the loop access the same memref except 454 // for transfer_read/transfer_write accessing statically disjoint 455 // slices. 456 if (transferRead.getIndices() != transferWrite.getIndices() && 457 transferRead.getVectorType() == transferWrite.getVectorType()) 458 return WalkResult::advance(); 459 460 // TODO: may want to memoize this information for performance but it 461 // likely gets invalidated often. 
462 DominanceInfo dom(loop); 463 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 464 return WalkResult::advance(); 465 for (auto &use : transferRead.getSource().getUses()) { 466 if (!loop->isAncestor(use.getOwner())) 467 continue; 468 if (use.getOwner() == transferRead.getOperation() || 469 use.getOwner() == transferWrite.getOperation()) 470 continue; 471 if (auto transferWriteUse = 472 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 473 if (!vector::isDisjointTransferSet( 474 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 475 cast<VectorTransferOpInterface>( 476 transferWriteUse.getOperation()))) 477 return WalkResult::advance(); 478 } else if (auto transferReadUse = 479 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 480 if (!vector::isDisjointTransferSet( 481 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 482 cast<VectorTransferOpInterface>( 483 transferReadUse.getOperation()))) 484 return WalkResult::advance(); 485 } else { 486 // Unknown use, we cannot prove that it doesn't alias with the 487 // transferRead/transferWrite operations. 488 return WalkResult::advance(); 489 } 490 } 491 492 // Hoist read before. 493 loop.moveOutOfLoop(transferRead); 494 495 // Hoist write after. 496 transferWrite->moveAfter(loop); 497 498 // Rewrite `loop` with new yields by cloning and erase the original loop. 499 OpBuilder b(transferRead); 500 NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc, 501 ArrayRef<BlockArgument> newBBArgs) { 502 return SmallVector<Value>{transferWrite.getVector()}; 503 }; 504 auto newForOp = 505 replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn); 506 507 // Transfer write has been hoisted, need to update the written vector by 508 // the value yielded by the newForOp. 
509 transferWrite.getVectorMutable().assign(newForOp.getResults().back()); 510 511 changed = true; 512 loop.erase(); 513 // Need to interrupt and restart because erasing the loop messes up the 514 // walk. 515 return WalkResult::interrupt(); 516 }); 517 } 518 } 519