1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements functions concerned with hoisting invariant operations 10 // in the context of Linalg transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/SCF/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Dialect/Vector/VectorUtils.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/Dominance.h" 24 #include "mlir/Transforms/LoopUtils.h" 25 #include "llvm/ADT/StringRef.h" 26 #include "llvm/Support/Debug.h" 27 28 #define DEBUG_TYPE "linalg-hoisting" 29 30 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 31 32 using namespace mlir; 33 using namespace mlir::linalg; 34 35 using llvm::dbgs; 36 37 void mlir::linalg::hoistViewAllocOps(FuncOp func) { 38 bool changed = true; 39 while (changed) { 40 changed = false; 41 func.walk([&changed](Operation *op) { 42 if (!isa<AllocOp, AllocaOp, DeallocOp>(op)) 43 return; 44 45 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n"); 46 auto loop = dyn_cast<scf::ForOp>(op->getParentOp()); 47 LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n"); 48 49 // Only hoist out of immediately enclosing scf::ForOp. 50 if (!loop) 51 return; 52 53 // If any operand is defined inside the loop don't hoist. 54 if (llvm::any_of(op->getOperands(), [&](Value v) { 55 return !loop.isDefinedOutsideOfLoop(v); 56 })) 57 return; 58 59 LLVM_DEBUG(DBGS() << "All operands defined outside \n"); 60 61 // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist. 62 Value v; 63 if (op->getNumResults() > 0) { 64 assert(op->getNumResults() == 1 && "Unexpected multi-result alloc"); 65 v = op->getResult(0); 66 } 67 if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) { 68 return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner()); 69 })) { 70 LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n"); 71 return; 72 } 73 74 // Move AllocOp before the loop. 75 if (isa<AllocOp, AllocaOp>(op)) 76 loop.moveOutOfLoop({op}); 77 else // Move DeallocOp outside of the loop. 78 op->moveAfter(loop); 79 changed = true; 80 }); 81 } 82 } 83 84 /// Look for a transfer_read, in the given tensor uses, accessing the same 85 /// offset as the transfer_write. 86 static vector::TransferReadOp 87 findMatchingTransferRead(vector::TransferWriteOp write, Value srcTensor) { 88 for (Operation *user : srcTensor.getUsers()) { 89 auto read = dyn_cast<vector::TransferReadOp>(user); 90 if (read && read.indices() == write.indices() && 91 read.getVectorType() == write.getVectorType()) { 92 return read; 93 } 94 } 95 return nullptr; 96 } 97 98 /// Check if the chunk of data inserted by the transfer_write in the given 99 /// tensor are read by any other op than the read candidate. 100 static bool tensorChunkAccessedByUnknownOp(vector::TransferWriteOp write, 101 vector::TransferReadOp candidateRead, 102 Value srcTensor) { 103 // Make sure none of the other uses read the part of the tensor modified 104 // by the transfer_write. 105 llvm::SmallVector<Value::use_range, 1> uses; 106 uses.push_back(srcTensor.getUses()); 107 while (!uses.empty()) { 108 for (OpOperand &use : uses.pop_back_val()) { 109 Operation *user = use.getOwner(); 110 // Skip the candidate use, only inspect the "other" uses. 111 if (user == candidateRead.getOperation() || user == write.getOperation()) 112 continue; 113 // Consider all transitive uses through a vector.transfer_write. 114 if (auto writeUser = dyn_cast<vector::TransferWriteOp>(user)) { 115 uses.push_back(writeUser->getResult(0).getUses()); 116 continue; 117 } 118 // Consider all nested uses through an scf::ForOp. We may have 119 // pass-through tensor arguments left from previous level of 120 // hoisting. 121 if (auto forUser = dyn_cast<scf::ForOp>(user)) { 122 Value arg = forUser.getLoopBody().getArgument( 123 use.getOperandNumber() - forUser.getNumControlOperands() + 124 /*iv value*/ 1); 125 uses.push_back(arg.getUses()); 126 continue; 127 } 128 // Follow the use yield as long as it doesn't escape the original 129 // region. 130 scf::YieldOp yieldUser = dyn_cast<scf::YieldOp>(user); 131 if (yieldUser && 132 write->getParentOp()->isAncestor(yieldUser->getParentOp())) { 133 Value ret = yieldUser->getParentOp()->getResult(use.getOperandNumber()); 134 uses.push_back(ret.getUses()); 135 continue; 136 } 137 auto read = dyn_cast<vector::TransferReadOp>(user); 138 if (!read || !isDisjointTransferIndices( 139 cast<VectorTransferOpInterface>(read.getOperation()), 140 cast<VectorTransferOpInterface>(write.getOperation()))) { 141 return true; 142 } 143 } 144 } 145 return false; 146 } 147 148 // To hoist transfer op on tensor the logic can be significantly simplified 149 // compared to the case on buffer. The transformation follows this logic: 150 // 1. Look for transfer_write with a single use from ForOp yield 151 // 2. Check the uses of the matching block argument and look for a transfer_read 152 // with the same indices. 153 // 3. Check that all the other uses of the tensor argument are either disjoint 154 // tensor_read or transfer_write. For transfer_write uses recurse to make sure 155 // the new tensor has the same restrictions on its uses. 156 // 4. Hoist the tensor_read/tensor_write and update the tensor SSA links. 157 // After this transformation the scf.forOp may have unused arguments that can be 158 // remove by the canonicalization pass. 159 void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) { 160 bool changed = true; 161 while (changed) { 162 changed = false; 163 func.walk([&](scf::ForOp forOp) { 164 Operation *yield = forOp.getBody()->getTerminator(); 165 for (auto it : llvm::enumerate(forOp.getRegionIterArgs())) { 166 Value ret = yield->getOperand(it.index()); 167 auto write = ret.getDefiningOp<vector::TransferWriteOp>(); 168 if (!write || !write->hasOneUse()) 169 continue; 170 LLVM_DEBUG(DBGS() << "Candidate write for hoisting: " 171 << *write.getOperation() << "\n"); 172 if (llvm::any_of(write.indices(), [&forOp](Value index) { 173 return !forOp.isDefinedOutsideOfLoop(index); 174 })) 175 continue; 176 // Find a read with the same type and indices. 177 vector::TransferReadOp matchingRead = 178 findMatchingTransferRead(write, it.value()); 179 // Make sure none of the other uses read the part of the tensor modified 180 // by the transfer_write. 181 if (!matchingRead || 182 tensorChunkAccessedByUnknownOp(write, matchingRead, it.value())) 183 continue; 184 185 // Hoist read before. 186 if (failed(forOp.moveOutOfLoop({matchingRead}))) 187 llvm_unreachable( 188 "Unexpected failure to move transfer read out of loop"); 189 // Update the source tensor. 190 matchingRead.sourceMutable().assign(forOp.initArgs()[it.index()]); 191 192 // Hoist write after. 193 write->moveAfter(forOp); 194 yield->setOperand(it.index(), write.source()); 195 196 // Rewrite `loop` with new yields by cloning and erase the original 197 // loop. 198 OpBuilder b(matchingRead); 199 auto newForOp = 200 cloneWithNewYields(b, forOp, matchingRead.vector(), write.vector()); 201 202 // Transfer write has been hoisted, need to update the vector and tensor 203 // source. Replace the result of the loop to use the new tensor created 204 // outside the loop. 205 newForOp.getResult(it.index()).replaceAllUsesWith(write.getResult(0)); 206 write.vectorMutable().assign(newForOp.getResults().back()); 207 write.sourceMutable().assign(newForOp.getResult(it.index())); 208 209 changed = true; 210 forOp.erase(); 211 // Need to interrupt and restart because erasing the loop messes up the 212 // walk. 213 return WalkResult::interrupt(); 214 } 215 return WalkResult::advance(); 216 }); 217 } 218 } 219 220 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 221 bool changed = true; 222 while (changed) { 223 changed = false; 224 225 func.walk([&](vector::TransferReadOp transferRead) { 226 if (!transferRead.getShapedType().isa<MemRefType>()) 227 return WalkResult::advance(); 228 229 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 230 << *transferRead.getOperation() << "\n"); 231 auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp()); 232 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() 233 << "\n"); 234 if (!loop) 235 return WalkResult::advance(); 236 237 if (failed(moveLoopInvariantCode( 238 cast<LoopLikeOpInterface>(loop.getOperation())))) 239 llvm_unreachable( 240 "Unexpected failure to move invariant code out of loop"); 241 242 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 243 << "\n"); 244 245 llvm::SetVector<Operation *> forwardSlice; 246 getForwardSlice(transferRead, &forwardSlice); 247 248 // Look for the last TransferWriteOp in the forwardSlice of 249 // `transferRead` that operates on the same memref. 250 vector::TransferWriteOp transferWrite; 251 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 252 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 253 if (!candidateWrite || candidateWrite.source() != transferRead.source()) 254 continue; 255 transferWrite = candidateWrite; 256 } 257 258 // All operands of the TransferRead must be defined outside of the loop. 259 for (auto operand : transferRead.getOperands()) 260 if (!loop.isDefinedOutsideOfLoop(operand)) 261 return WalkResult::advance(); 262 263 // Only hoist transfer_read / transfer_write pairs for now. 264 if (!transferWrite) 265 return WalkResult::advance(); 266 267 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 268 << "\n"); 269 270 // Approximate aliasing by checking that: 271 // 1. indices are the same, 272 // 2. no other operations in the loop access the same memref except 273 // for transfer_read/transfer_write accessing statically disjoint 274 // slices. 275 if (transferRead.indices() != transferWrite.indices() && 276 transferRead.getVectorType() == transferWrite.getVectorType()) 277 return WalkResult::advance(); 278 279 // TODO: may want to memoize this information for performance but it 280 // likely gets invalidated often. 281 DominanceInfo dom(loop); 282 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 283 return WalkResult::advance(); 284 for (auto &use : transferRead.source().getUses()) { 285 if (!dom.properlyDominates(loop, use.getOwner())) 286 continue; 287 if (use.getOwner() == transferRead.getOperation() || 288 use.getOwner() == transferWrite.getOperation()) 289 continue; 290 if (auto transferWriteUse = 291 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 292 if (!isDisjointTransferSet( 293 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 294 cast<VectorTransferOpInterface>( 295 transferWriteUse.getOperation()))) 296 return WalkResult::advance(); 297 } else if (auto transferReadUse = 298 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 299 if (!isDisjointTransferSet( 300 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 301 cast<VectorTransferOpInterface>( 302 transferReadUse.getOperation()))) 303 return WalkResult::advance(); 304 } else { 305 // Unknown use, we cannot prove that it doesn't alias with the 306 // transferRead/transferWrite operations. 307 return WalkResult::advance(); 308 } 309 } 310 311 // Hoist read before. 312 if (failed(loop.moveOutOfLoop({transferRead}))) 313 llvm_unreachable( 314 "Unexpected failure to move transfer read out of loop"); 315 316 // Hoist write after. 317 transferWrite->moveAfter(loop); 318 319 // Rewrite `loop` with new yields by cloning and erase the original loop. 320 OpBuilder b(transferRead); 321 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 322 transferWrite.vector()); 323 324 // Transfer write has been hoisted, need to update the written value to 325 // the value yielded by the newForOp. 326 transferWrite.vector().replaceAllUsesWith( 327 newForOp.getResults().take_back()[0]); 328 329 changed = true; 330 loop.erase(); 331 // Need to interrupt and restart because erasing the loop messes up the 332 // walk. 333 return WalkResult::interrupt(); 334 }); 335 } 336 } 337