//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableWrite {
  vector::TransferWriteOp transferWriteOp;
  tensor::InsertSliceOp insertSliceOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableRead {
  vector::TransferReadOp transferReadOp;
  tensor::ExtractSliceOp extractSliceOp;
};
} // namespace
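
// As an illustration (a sketch only; types, indices and attributes are
// hypothetical), a HoistableWrite may pair the two ops:
//
//   %w = vector.transfer_write %vec, %slice[%c0, %c0]
//     : vector<4x4xf32>, tensor<4x4xf32>
//   %t = tensor.insert_slice %w into %iter_arg[%i, %j] [4, 4] [1, 1]
//     : tensor<4x4xf32> into tensor<?x?xf32>
//
// and the matching HoistableRead pairs the corresponding tensor.extract_slice
// with a vector.transfer_read at the same indices.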

/// Return true if op1 and op2 are the same integer constant or the same SSA
/// value.
static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
    Attribute attr = ofr.dyn_cast<Attribute>();
    // Note: isa+cast-like pattern allows writing the condition below as 1
    // line.
    if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
      attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
      return intAttr.getValue().getSExtValue();
    return llvm::None;
  };
  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
  if (cst1 && cst2 && *cst1 == *cst2)
    return true;
  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
  return v1 && v2 && v1 == v2;
}

/// Return true if all offsets, sizes and strides are equal.
static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
                                       tensor::InsertSliceOp si) {
  if (s.static_offsets().size() != si.static_offsets().size())
    return false;
  if (s.static_sizes().size() != si.static_sizes().size())
    return false;
  if (s.static_strides().size() != si.static_strides().size())
    return false;
  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  return true;
}

/// Look for a HoistableRead, in the given tensor uses, accessing the same
/// offset as the HoistableWrite.
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.insertSliceOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead insertSliceOp: "
                      << *write.insertSliceOp.getOperation() << "\n");

  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If the HoistableWrite involves an InsertSliceOp, we need to find a
    // matching ExtractSliceOp.
    tensor::ExtractSliceOp sliceOp;
    Operation *maybeTransferReadUser = user;
    if (write.insertSliceOp) {
      sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
      if (!sliceOp || sliceOp.getResult().getType() !=
                          write.insertSliceOp.source().getType())
        continue;

      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *sliceOp << " vs " << *write.insertSliceOp << "\n");
      if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
      //   1. the transfer_write we want to hoist.
      //   2. a matching transfer_read.
      // Anything else, we skip.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : sliceOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        if (otherUser) {
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    if (read && read.indices() == write.transferWriteOp.indices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, sliceOp};
  }
  return HoistableRead();
}

/// Check if the chunk of data inserted by the HoistableWrite is read by any
/// op other than the HoistableRead candidate.
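/// For example (an illustrative sketch; names and shapes are hypothetical),
/// hoisting must be skipped when the loop body also contains
///
///   %other = vector.transfer_read %iter_arg[%c0, %c0], %pad
///     : tensor<?x?xf32>, vector<8x8xf32>
///
/// unless this read can be proven disjoint from the slice written by the
/// candidate transfer_write.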
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Consider all transitive uses through an extract_slice / insert_slice.
      // TODO: atm we just bail because a stronger analysis is needed for
      // these cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from a previous level of hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use yield as long as it doesn't escape the original
      // region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !vector::isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}

/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional insert_slice or if any of the indexings
/// is `forOp`-dependent.
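/// For example (a sketch; names, shapes and indices are illustrative), with
/// %arg1 the iter_args block argument matching the yield operand position:
///
///   %w = vector.transfer_write %v, %s[%c0]
///     : vector<4xf32>, tensor<4xf32>
///   %i = tensor.insert_slice %w into %arg1[%off] [4] [1]
///     : tensor<4xf32> into tensor<?xf32>
///   scf.yield %i : tensor<?xf32>
///
/// where %off and all other offsets, sizes and strides are defined outside
/// the loop.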
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
    // Inserted slice must come from vector.transfer_write.
    auto write =
        insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching
    // yieldOperand's.
    auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    for (Value operand : insertSliceOp->getOperands().drop_front(
             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, insertSliceOp};
  }

  return HoistableWrite();
}

/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp})))
    llvm_unreachable("Unexpected failure moving extract_slice out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor.
  if (read.extractSliceOp)
    read.extractSliceOp.sourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);

  // Hoist the write after the loop.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield.
  auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `forOp` with additional new yields.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // The transfer write has been hoisted; update its vector and tensor
  // sources. Replace the result of the loop with the new tensor created
  // outside the loop.
  // Depending on whether an insert_slice is present or not, it carries the
  // update on the tensor operands.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
    write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult());
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yielded tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}

// To hoist transfer ops on tensors, the logic can be significantly simplified
// compared to the case on buffers. The transformation follows this logic:
// 1. Look for a transfer_write whose single use is a ForOp yield (possibly
//    through an insert_slice).
// 2. Check the uses of the matching block argument and look for a
//    transfer_read with the same indices.
// 3. Check that all the other uses of the tensor argument are either
//    transfer_reads disjoint from the written chunk, or transfer_writes. For
//    transfer_write uses, recurse to make sure the new tensor has the same
//    restrictions on its uses.
// 4. Hoist the transfer_read / transfer_write pair and update the tensor SSA
//    links.
// After this transformation the scf.ForOp may have unused arguments that can
// be removed by the canonicalization pass.
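//
// For example (an illustrative sketch only; "some_use" stands for arbitrary
// vector ops), this rewrites
//
//   %r = scf.for ... iter_args(%arg = %t) -> (tensor<?xf32>) {
//     %read = vector.transfer_read %arg[%c0], %pad
//       : tensor<?xf32>, vector<4xf32>
//     %v = "some_use"(%read) : (vector<4xf32>) -> vector<4xf32>
//     %w = vector.transfer_write %v, %arg[%c0]
//       : vector<4xf32>, tensor<?xf32>
//     scf.yield %w : tensor<?xf32>
//   }
//
// into
//
//   %read = vector.transfer_read %t[%c0], %pad : tensor<?xf32>, vector<4xf32>
//   %f:2 = scf.for ... iter_args(%arg = %t, %vec = %read)
//       -> (tensor<?xf32>, vector<4xf32>) {
//     %v = "some_use"(%vec) : (vector<4xf32>) -> vector<4xf32>
//     scf.yield %arg, %v : tensor<?xf32>, vector<4xf32>
//   }
//   %w = vector.transfer_write %f#1, %f#0[%c0] : vector<4xf32>, tensor<?xf32>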
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor
        // modified by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}
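
// Hoist loop-invariant vector.transfer_read / vector.transfer_write pairs
// that operate on buffers (memrefs) out of their immediately enclosing
// scf.for. For example (an illustrative sketch only; "some_use" stands for
// arbitrary vector ops), this rewrites
//
//   scf.for %i = %lb to %ub step %step {
//     %r = vector.transfer_read %m[%c0], %pad : memref<?xf32>, vector<4xf32>
//     %v = "some_use"(%r) : (vector<4xf32>) -> vector<4xf32>
//     vector.transfer_write %v, %m[%c0] : vector<4xf32>, memref<?xf32>
//   }
//
// into
//
//   %r = vector.transfer_read %m[%c0], %pad : memref<?xf32>, vector<4xf32>
//   %f = scf.for %i = %lb to %ub step %step iter_args(%vec = %r)
//       -> (vector<4xf32>) {
//     %v = "some_use"(%vec) : (vector<4xf32>) -> vector<4xf32>
//     scf.yield %v : vector<4xf32>
//   }
//   vector.transfer_write %f, %m[%c0] : vector<4xf32>, memref<?xf32>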
void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    // First move loop-invariant ops outside of their loop. This needs to
    // happen first, as we cannot move ops without interrupting the function
    // walk.
    func.walk([&](LoopLikeOpInterface loopLike) {
      if (failed(moveLoopInvariantCode(loopLike)))
        llvm_unreachable(
            "Unexpected failure to move invariant code out of loop");
    });

    func.walk([&](vector::TransferReadOp transferRead) {
      if (!transferRead.getShapedType().isa<MemRefType>())
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *transferRead.getOperation() << "\n");
      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
      LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                        << "\n");
      if (!loop)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
                        << "\n");

      SetVector<Operation *> forwardSlice;
      getForwardSlice(transferRead.getOperation(), &forwardSlice);

      // Look for the last TransferWriteOp in the forwardSlice of
      // `transferRead` that operates on the same memref.
      vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite ||
            candidateWrite.source() != transferRead.source())
          continue;
        transferWrite = candidateWrite;
      }

      // All operands of the TransferRead must be defined outside of the loop.
      for (auto operand : transferRead.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand))
          return WalkResult::advance();

      // Only hoist transfer_read / transfer_write pairs for now.
      if (!transferWrite)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                        << "\n");

      // Approximate aliasing by checking that:
      //   1. indices and vector type are the same,
      //   2. no other operations in the loop access the same memref except
      //      for transfer_read/transfer_write accessing statically disjoint
      //      slices.
      if (transferRead.indices() != transferWrite.indices() ||
          transferRead.getVectorType() != transferWrite.getVectorType())
        return WalkResult::advance();

      // TODO: may want to memoize this information for performance but it
      // likely gets invalidated often.
      DominanceInfo dom(loop);
      if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
        return WalkResult::advance();
      for (auto &use : transferRead.source().getUses()) {
        if (!loop->isAncestor(use.getOwner()))
          continue;
        if (use.getOwner() == transferRead.getOperation() ||
            use.getOwner() == transferWrite.getOperation())
          continue;
        if (auto transferWriteUse =
                dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferWriteUse.getOperation())))
            return WalkResult::advance();
        } else if (auto transferReadUse =
                       dyn_cast<vector::TransferReadOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferReadUse.getOperation())))
            return WalkResult::advance();
        } else {
          // Unknown use, we cannot prove that it doesn't alias with the
          // transferRead/transferWrite operations.
          return WalkResult::advance();
        }
      }

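      // At this point the pair is provably hoistable: the read dominates the
      // write, both are loop-invariant, and every other in-loop access to the
      // same memref is a disjoint transfer.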
      // Hoist the read before the loop.
      if (failed(loop.moveOutOfLoop({transferRead})))
        llvm_unreachable(
            "Unexpected failure to move transfer read out of loop");

      // Hoist the write after the loop.
      transferWrite->moveAfter(loop);

      // Rewrite `loop` with new yields by cloning, then erase the original
      // loop.
      OpBuilder b(transferRead);
      auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
                                         transferWrite.vector());

      // The transfer write has been hoisted; update the written value to
      // the value yielded by the newForOp.
      transferWrite.vector().replaceAllUsesWith(
          newForOp.getResults().take_back()[0]);

      changed = true;
      loop.erase();
      // Need to interrupt and restart because erasing the loop messes up the
      // walk.
      return WalkResult::interrupt();
    });
  }
}