//===- Hoisting.cpp - Linalg hoisting transformations --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"

using llvm::dbgs;

#define DEBUG_TYPE "linalg-hoisting"

#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Represents a unit of hoistable TransferWriteOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableWrite {
  vector::TransferWriteOp transferWriteOp;
  tensor::InsertSliceOp insertSliceOp;
};
/// Represents a unit of hoistable TransferReadOp. This may comprise other
/// instructions that need to be hoisted too.
struct HoistableRead {
  vector::TransferReadOp transferReadOp;
  tensor::ExtractSliceOp extractSliceOp;
};
} // namespace

/// Return true if op1 and op2 are the same constant or the same SSA value.
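/// Illustrative example (hypothetical values): an offset given as an SSA
/// value %c4 defined by `arith.constant 4 : index` compares equal to the
/// static attribute offset 4, whereas two distinct non-constant SSA values
/// are conservatively treated as unequal, even if they happen to hold the
/// same value at runtime.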
static bool isEqualOffsetSizeOrStride(OpFoldResult op1, OpFoldResult op2) {
  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
    Attribute attr = ofr.dyn_cast<Attribute>();
    // Note: the isa+cast-like pattern allows writing the condition below on
    // one line.
    if (!attr && ofr.get<Value>().getDefiningOp<arith::ConstantOp>())
      attr = ofr.get<Value>().getDefiningOp<arith::ConstantOp>().getValue();
    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
      return intAttr.getValue().getSExtValue();
    return llvm::None;
  };
  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
  if (cst1 && cst2 && *cst1 == *cst2)
    return true;
  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
  return v1 && v2 && v1 == v2;
}

/// Return true if all offsets, sizes and strides are equal.
static bool sameOffsetsSizesAndStrides(tensor::ExtractSliceOp s,
                                       tensor::InsertSliceOp si) {
  if (s.static_offsets().size() != si.static_offsets().size())
    return false;
  if (s.static_sizes().size() != si.static_sizes().size())
    return false;
  if (s.static_strides().size() != si.static_strides().size())
    return false;
  for (auto it : llvm::zip(s.getMixedOffsets(), si.getMixedOffsets()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedSizes(), si.getMixedSizes()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  for (auto it : llvm::zip(s.getMixedStrides(), si.getMixedStrides()))
    if (!isEqualOffsetSizeOrStride(std::get<0>(it), std::get<1>(it)))
      return false;
  return true;
}

/// Look for a HoistableRead, among the uses of the given tensor, that
/// accesses the same offset as the HoistableWrite.
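/// When the write involves an insert_slice, the matched pattern looks roughly
/// like the following sketch (hypothetical IR, types elided):
///
///   %slice = tensor.extract_slice %src[%off][%sz][%str]
///   %read  = vector.transfer_read %slice[...], %pad
///   ...
///   %write = vector.transfer_write %vec, %slice[...]
///   tensor.insert_slice %write into %src[%off][%sz][%str]
///
/// where the extract_slice / insert_slice pair agrees on offsets, sizes and
/// strides, and %slice has exactly two uses: the matched transfer_read and
/// the transfer_write being hoisted.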
static HoistableRead findMatchingTransferRead(HoistableWrite write,
                                              Value srcTensor) {
  assert(write.transferWriteOp &&
         "expected hoistable write to have a .transfer_write");

  LLVM_DEBUG(DBGS() << "findMatchingTransferRead for: "
                    << *write.transferWriteOp.getOperation() << "\n");
  if (write.insertSliceOp)
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead insertSliceOp: "
                      << *write.insertSliceOp.getOperation() << "\n");

  for (Operation *user : srcTensor.getUsers()) {
    LLVM_DEBUG(DBGS() << "findMatchingTransferRead inspect user: " << *user
                      << "\n");

    // If the HoistableWrite involves an InsertSliceOp, we need to find a
    // matching ExtractSliceOp.
    tensor::ExtractSliceOp sliceOp;
    Operation *maybeTransferReadUser = user;
    if (write.insertSliceOp) {
      sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
      if (!sliceOp || sliceOp.getResult().getType() !=
                          write.insertSliceOp.source().getType())
        continue;

      LLVM_DEBUG(DBGS() << "check whether sameOffsetsSizesAndStrides: "
                        << *sliceOp << " vs " << *write.insertSliceOp << "\n");
      if (!sameOffsetsSizesAndStrides(sliceOp, write.insertSliceOp))
        continue;

      LLVM_DEBUG(DBGS() << "sameOffsetsSizesAndStrides: SUCCESS\n");
      // If we got here, sliceOp is hoistable iff it has exactly 2 uses:
      //   1. the transfer_write we want to hoist.
      //   2. a matching transfer_read.
      // Anything else, we skip.
      bool skip = false;
      Operation *otherUser = nullptr;
      for (Operation *u : sliceOp->getUsers()) {
        if (u == write.transferWriteOp)
          continue;
        if (otherUser) {
          skip = true;
          break;
        }
        otherUser = u;
      }
      if (skip || !otherUser)
        continue;
      maybeTransferReadUser = otherUser;
    }

    LLVM_DEBUG(DBGS() << "maybeTransferReadUser: " << *maybeTransferReadUser
                      << "\n");
    auto read = dyn_cast<vector::TransferReadOp>(maybeTransferReadUser);
    if (read && read.indices() == write.transferWriteOp.indices() &&
        read.getVectorType() == write.transferWriteOp.getVectorType())
      return HoistableRead{read, sliceOp};
  }
  return HoistableRead();
}

/// Check if the chunk of data inserted by the HoistableWrite is read by any
/// op other than the HoistableRead candidate.
static bool tensorChunkAccessedByUnknownOp(HoistableWrite write,
                                           HoistableRead candidateRead,
                                           BlockArgument tensorArg) {
  // Make sure none of the other uses read the part of the tensor modified
  // by the transfer_write.
  llvm::SmallVector<Value::use_range, 1> uses;
  uses.push_back(tensorArg.getUses());
  while (!uses.empty()) {
    for (OpOperand &use : uses.pop_back_val()) {
      Operation *user = use.getOwner();
      // Skip the candidate use, only inspect the "other" uses.
      if (user == candidateRead.transferReadOp ||
          user == candidateRead.extractSliceOp ||
          user == write.transferWriteOp || user == write.insertSliceOp)
        continue;
      // Consider all transitive uses through an extract_slice / insert_slice.
      // TODO: at the moment we just bail because a stronger analysis is
      // needed for these cases.
      if (isa<tensor::ExtractSliceOp, tensor::InsertSliceOp>(user))
        return true;
      // Consider all transitive uses through a vector.transfer_write.
      if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) {
        uses.push_back(writeUser->getResult(0).getUses());
        continue;
      }
      // Consider all nested uses through an scf::ForOp. We may have
      // pass-through tensor arguments left from a previous level of
      // hoisting.
      if (auto forUser = dyn_cast<scf::ForOp>(user)) {
        Value arg = forUser.getLoopBody().getArgument(
            use.getOperandNumber() - forUser.getNumControlOperands() +
            /*iv value*/ 1);
        uses.push_back(arg.getUses());
        continue;
      }
      // Follow the use through scf.yield as long as it doesn't escape the
      // original region.
      scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user);
      if (yieldUser && write.transferWriteOp->getParentOp()->isAncestor(
                           yieldUser->getParentOp())) {
        Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber());
        uses.push_back(ret.getUses());
        continue;
      }
      auto read = dyn_cast<vector::TransferReadOp>(user);
      if (!read || !vector::isDisjointTransferIndices(
                       cast<VectorTransferOpInterface>(read.getOperation()),
                       cast<VectorTransferOpInterface>(
                           write.transferWriteOp.getOperation()))) {
        return true;
      }
    }
  }
  return false;
}

/// Return the `forOp`-invariant HoistableWrite that produces `yieldOperand`.
/// Return the null HoistableWrite() if it is not comprised of a
/// vector.transfer_write + optional insert_slice or if any of the indexings
/// is `forOp`-dependent.
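/// Concretely, the yielded value must match one of the two following shapes
/// (sketched with hypothetical IR, types elided):
///
///   %w = vector.transfer_write %v, %t[%i, %j]  // %i, %j defined above forOp
///   scf.yield ..., %w, ...
///
/// or:
///
///   %w = vector.transfer_write %v, %slice[...]
///   %r = tensor.insert_slice %w into %bbarg[%off][%sz][%str]
///   scf.yield ..., %r, ...
///
/// where %bbarg is the iter_arg of `forOp` at the position matching the yield
/// operand and %off/%sz/%str are all defined outside the loop.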
static HoistableWrite
getLoopInvariantTransferWriteOpDefining(scf::ForOp forOp,
                                        OpOperand &yieldOperand) {
  Value v = yieldOperand.get();
  if (auto write = v.getDefiningOp<vector::TransferWriteOp>()) {
    // Indexing must not depend on `forOp`.
    for (Value operand : write.indices())
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, nullptr};
  }

  if (auto insertSliceOp = v.getDefiningOp<tensor::InsertSliceOp>()) {
    // Inserted slice must come from vector.transfer_write.
    auto write =
        insertSliceOp.source().getDefiningOp<vector::TransferWriteOp>();
    if (!write)
      return HoistableWrite();

    // Tensor inserted into must be a BBArg at position matching
    // yieldOperand's.
    auto bbArg = insertSliceOp.dest().dyn_cast<BlockArgument>();
    if (!bbArg || bbArg.getOwner()->getParentOp() != forOp ||
        bbArg.getArgNumber() != /*num iv=*/1 + yieldOperand.getOperandNumber())
      return HoistableWrite();

    // Indexing inserted into must not depend on `forOp`.
    for (Value operand : insertSliceOp->getOperands().drop_front(
             tensor::InsertSliceOp::getOffsetSizeAndStrideStartOperandIndex()))
      if (!forOp.isDefinedOutsideOfLoop(operand))
        return HoistableWrite();

    return HoistableWrite{write, insertSliceOp};
  }

  return HoistableWrite();
}

/// Mechanical hoisting of a matching HoistableRead / HoistableWrite pair.
static void hoistReadWrite(HoistableRead read, HoistableWrite write,
                           BlockArgument tensorBBArg) {
  scf::ForOp forOp = cast<scf::ForOp>(tensorBBArg.getOwner()->getParentOp());
  assert(read.transferReadOp && write.transferWriteOp &&
         "expected transfer_read and transfer_write ops to be set");
  assert(((read.extractSliceOp && write.insertSliceOp) ||
          (!read.extractSliceOp && !write.insertSliceOp)) &&
         "expected matching extract_slice / insert_slice");
  LLVM_DEBUG(DBGS() << "In forOp:\n"
                    << *forOp.getOperation()
                    << "\nHoist: " << *read.transferReadOp.getOperation()
                    << "\nHoist: " << *write.transferWriteOp.getOperation()
                    << "\nInvolving: " << tensorBBArg << "\n");

  // If a read slice is present, hoist it.
  if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp})))
    llvm_unreachable("Unexpected failure moving extract_slice out of loop");

  // Hoist the transfer_read op.
  if (failed(forOp.moveOutOfLoop({read.transferReadOp})))
    llvm_unreachable("Unexpected failure moving transfer read out of loop");

  // TODO: don't hardcode /*numIvs=*/1.
  assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1);
  unsigned initArgNumber = tensorBBArg.getArgNumber() - /*numIvs=*/1;

  // Update the source tensor.
  if (read.extractSliceOp)
    read.extractSliceOp.sourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);
  else
    read.transferReadOp.sourceMutable().assign(
        forOp.getInitArgs()[initArgNumber]);

  // Hoist the write after the loop.
  if (write.insertSliceOp)
    write.insertSliceOp->moveAfter(forOp);
  write.transferWriteOp->moveAfter(forOp);

  // Update the yield.
  auto yieldOp = cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
  if (write.insertSliceOp)
    yieldOp->setOperand(initArgNumber, write.insertSliceOp.dest());
  else
    yieldOp->setOperand(initArgNumber, write.transferWriteOp.source());

  // Rewrite `forOp` with additional new yields.
  OpBuilder b(read.transferReadOp);
  auto newForOp = cloneWithNewYields(b, forOp, read.transferReadOp.vector(),
                                     write.transferWriteOp.vector());
  // The transfer write has been hoisted; we need to update its vector and
  // tensor sources. Replace the result of the loop with the new tensor
  // created outside the loop.
  // Depending on whether an insert_slice is present or not, it carries the
  // update on the tensor operands.
  if (write.insertSliceOp) {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.insertSliceOp.getResult());
    write.transferWriteOp.sourceMutable().assign(read.extractSliceOp.result());
    write.insertSliceOp.destMutable().assign(read.extractSliceOp.source());
  } else {
    newForOp.getResult(initArgNumber)
        .replaceAllUsesWith(write.transferWriteOp.getResult());
    write.transferWriteOp.sourceMutable().assign(
        newForOp.getResult(initArgNumber));
  }

  // Always update with the newly yielded tensor and vector.
  write.transferWriteOp.vectorMutable().assign(newForOp.getResults().back());
}

// To hoist transfer ops on tensors, the logic can be significantly simplified
// compared to the case on buffers. The transformation follows this logic:
// 1. Look for a transfer_write with a single use that is yielded by the
//    scf.for op.
// 2. Check the uses of the matching block argument and look for a
//    transfer_read with the same indices.
// 3. Check that all the other uses of the tensor argument are either disjoint
//    transfer_read or transfer_write ops. For transfer_write uses, recurse to
//    make sure the new tensor has the same restrictions on its uses.
// 4. Hoist the transfer_read/transfer_write pair and update the tensor SSA
//    links.
// After this transformation the scf.for op may have unused arguments that can
// be removed by the canonicalization pass.
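// As a rough before/after sketch in the simplest case (hypothetical IR, types
// and bounds elided):
//
//   %r = scf.for %i = %lb to %ub step %s iter_args(%t = %init) -> (tensor<...>) {
//     %v = vector.transfer_read %t[%c0], %pad
//     %u = "some_compute"(%v)
//     %w = vector.transfer_write %u, %t[%c0]
//     scf.yield %w
//   }
//
// becomes:
//
//   %v0 = vector.transfer_read %init[%c0], %pad
//   %r:2 = scf.for %i = %lb to %ub step %s iter_args(%t = %init, %v = %v0)
//       -> (tensor<...>, vector<...>) {
//     %u = "some_compute"(%v)
//     scf.yield %t, %u
//   }
//   %w = vector.transfer_write %r#1, %r#0[%c0]
//
// and uses of the original loop result are redirected to %w. The now
// pass-through tensor iter_arg is expected to fold away via the scf.for
// canonicalization applied below.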
void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    func.walk([&](scf::ForOp forOp) {
      Operation *yield = forOp.getBody()->getTerminator();
      for (const auto &it : llvm::enumerate(forOp.getRegionIterArgs())) {
        OpOperand &ret = yield->getOpOperand(it.index());
        HoistableWrite write =
            getLoopInvariantTransferWriteOpDefining(forOp, ret);
        if (!write.transferWriteOp || !write.transferWriteOp->hasOneUse())
          continue;
        LLVM_DEBUG(dbgs() << "\n";
                   DBGS() << "Candidate write for hoisting: "
                          << *write.transferWriteOp.getOperation() << "\n");
        if (write.insertSliceOp)
          LLVM_DEBUG(DBGS() << "Candidate insert_slice for hoisting: "
                            << *write.insertSliceOp.getOperation() << "\n");
        if (llvm::any_of(write.transferWriteOp.indices(),
                         [&forOp](Value index) {
                           return !forOp.isDefinedOutsideOfLoop(index);
                         }))
          continue;
        // Find a read with the same type and indices.
        HoistableRead matchingRead =
            findMatchingTransferRead(write, it.value());
        // Make sure none of the other uses read the part of the tensor
        // modified by the transfer_write.
        if (!matchingRead.transferReadOp ||
            tensorChunkAccessedByUnknownOp(write, matchingRead, it.value()))
          continue;

        LLVM_DEBUG(DBGS() << "Start hoisting\n");
        hoistReadWrite(matchingRead, write, it.value());
        changed = true;
        forOp.erase();

        // Need to interrupt and restart: erasing the loop messes up the walk.
        return WalkResult::interrupt();
      }
      return WalkResult::advance();
    });
    // Apply canonicalization so the newForOp + yield folds immediately, thus
    // cleaning up the IR and potentially enabling more hoisting.
    if (changed) {
      RewritePatternSet patterns(func->getContext());
      scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
      (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
    }
  }
}

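// Hoist matching transfer_read / transfer_write pairs on memrefs out of their
// enclosing scf.for. As a rough sketch (hypothetical IR, types elided), this
// rewrites:
//
//   scf.for %i = %lb to %ub step %s {
//     %v = vector.transfer_read %memref[%c0], %pad
//     %u = "some_compute"(%v)
//     vector.transfer_write %u, %memref[%c0]
//   }
//
// into:
//
//   %v0 = vector.transfer_read %memref[%c0], %pad
//   %u = scf.for %i = %lb to %ub step %s iter_args(%v = %v0) -> (vector<...>) {
//     %un = "some_compute"(%v)
//     scf.yield %un
//   }
//   vector.transfer_write %u, %memref[%c0]
//
// provided no other access to %memref inside the loop may alias the
// read/written slice.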
void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
  bool changed = true;
  while (changed) {
    changed = false;
    // First, move loop-invariant ops outside of their loop. This needs to be
    // done before the walk below, as we cannot move ops without interrupting
    // the function walk.
    func.walk([&](LoopLikeOpInterface loopLike) {
      if (failed(moveLoopInvariantCode(loopLike)))
        llvm_unreachable(
            "Unexpected failure to move invariant code out of loop");
    });

    func.walk([&](vector::TransferReadOp transferRead) {
      if (!transferRead.getShapedType().isa<MemRefType>())
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *transferRead.getOperation() << "\n");
      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
      LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                        << "\n");
      if (!loop)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
                        << "\n");

      SetVector<Operation *> forwardSlice;
      getForwardSlice(transferRead.getOperation(), &forwardSlice);

      // Look for the last TransferWriteOp in the forwardSlice of
      // `transferRead` that operates on the same memref.
      vector::TransferWriteOp transferWrite;
      for (auto *sliceOp : llvm::reverse(forwardSlice)) {
        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
        if (!candidateWrite ||
            candidateWrite.source() != transferRead.source())
          continue;
        transferWrite = candidateWrite;
      }

      // All operands of the TransferRead must be defined outside of the loop.
      for (auto operand : transferRead.getOperands())
        if (!loop.isDefinedOutsideOfLoop(operand))
          return WalkResult::advance();

      // Only hoist transfer_read / transfer_write pairs for now.
      if (!transferWrite)
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                        << "\n");

      // Approximate aliasing by checking that:
      //   1. indices and vector type are the same (i.e., the transfer_read
      //      and transfer_write form a matching pair),
      //   2. no other operations in the loop access the same memref except
      //      for transfer_read/transfer_write accessing statically disjoint
      //      slices.
      if (transferRead.indices() != transferWrite.indices() ||
          transferRead.getVectorType() != transferWrite.getVectorType())
        return WalkResult::advance();

      // TODO: may want to memoize this information for performance but it
      // likely gets invalidated often.
      DominanceInfo dom(loop);
      if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
        return WalkResult::advance();
      for (auto &use : transferRead.source().getUses()) {
        if (!loop->isAncestor(use.getOwner()))
          continue;
        if (use.getOwner() == transferRead.getOperation() ||
            use.getOwner() == transferWrite.getOperation())
          continue;
        if (auto transferWriteUse =
                dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferWriteUse.getOperation())))
            return WalkResult::advance();
        } else if (auto transferReadUse =
                       dyn_cast<vector::TransferReadOp>(use.getOwner())) {
          if (!vector::isDisjointTransferSet(
                  cast<VectorTransferOpInterface>(transferWrite.getOperation()),
                  cast<VectorTransferOpInterface>(
                      transferReadUse.getOperation())))
            return WalkResult::advance();
        } else {
          // Unknown use, we cannot prove that it doesn't alias with the
          // transferRead/transferWrite operations.
          return WalkResult::advance();
        }
      }

      // Hoist the transfer_read before the loop.
      if (failed(loop.moveOutOfLoop({transferRead})))
        llvm_unreachable(
            "Unexpected failure to move transfer read out of loop");

      // Hoist the transfer_write after the loop.
      transferWrite->moveAfter(loop);

      // Rewrite `loop` with a new yield by cloning it; the original loop is
      // erased below.
      OpBuilder b(transferRead);
      auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
                                         transferWrite.vector());

      // The transfer write has been hoisted; update the written value to the
      // value yielded by the newForOp.
      transferWrite.vector().replaceAllUsesWith(
          newForOp.getResults().take_back()[0]);

      changed = true;
      loop.erase();
      // Need to interrupt and restart because erasing the loop messes up the
      // walk.
      return WalkResult::interrupt();
    });
  }
}