1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements functions concerned with hoisting invariant operations 10 // in the context of Linalg transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/SCF/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Dialect/Vector/VectorUtils.h" 22 #include "mlir/IR/BuiltinOps.h" 23 #include "mlir/IR/Dominance.h" 24 #include "mlir/Transforms/LoopUtils.h" 25 #include "llvm/ADT/StringRef.h" 26 #include "llvm/Support/Debug.h" 27 28 #define DEBUG_TYPE "linalg-hoisting" 29 30 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 31 32 using namespace mlir; 33 using namespace mlir::linalg; 34 35 using llvm::dbgs; 36 37 void mlir::linalg::hoistViewAllocOps(FuncOp func) { 38 bool changed = true; 39 while (changed) { 40 changed = false; 41 func.walk([&changed](Operation *op) { 42 if (!isa<AllocOp, AllocaOp, DeallocOp>(op)) 43 return; 44 45 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n"); 46 auto loop = dyn_cast<scf::ForOp>(op->getParentOp()); 47 LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n"); 48 49 // Only hoist out of immediately enclosing scf::ForOp. 50 if (!loop) 51 return; 52 53 // If any operand is defined inside the loop don't hoist. 54 if (llvm::any_of(op->getOperands(), [&](Value v) { 55 return !loop.isDefinedOutsideOfLoop(v); 56 })) 57 return; 58 59 LLVM_DEBUG(DBGS() << "All operands defined outside \n"); 60 61 // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist. 62 Value v; 63 if (op->getNumResults() > 0) { 64 assert(op->getNumResults() == 1 && "Unexpected multi-result alloc"); 65 v = op->getResult(0); 66 } 67 if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) { 68 return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner()); 69 })) { 70 LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n"); 71 return; 72 } 73 74 // Move AllocOp before the loop. 75 if (isa<AllocOp, AllocaOp>(op)) 76 loop.moveOutOfLoop({op}); 77 else // Move DeallocOp outside of the loop. 78 op->moveAfter(loop); 79 changed = true; 80 }); 81 } 82 } 83 84 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 85 bool changed = true; 86 while (changed) { 87 changed = false; 88 89 func.walk([&](vector::TransferReadOp transferRead) { 90 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 91 << *transferRead.getOperation() << "\n"); 92 auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp()); 93 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() 94 << "\n"); 95 if (!loop) 96 return WalkResult::advance(); 97 98 if (failed(moveLoopInvariantCode( 99 cast<LoopLikeOpInterface>(loop.getOperation())))) 100 llvm_unreachable( 101 "Unexpected failure to move invariant code out of loop"); 102 103 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 104 << "\n"); 105 106 llvm::SetVector<Operation *> forwardSlice; 107 getForwardSlice(transferRead, &forwardSlice); 108 109 // Look for the last TransferWriteOp in the forwardSlice of 110 // `transferRead` that operates on the same memref. 111 vector::TransferWriteOp transferWrite; 112 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 113 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 114 if (!candidateWrite || candidateWrite.source() != transferRead.source()) 115 continue; 116 transferWrite = candidateWrite; 117 } 118 119 // All operands of the TransferRead must be defined outside of the loop. 120 for (auto operand : transferRead.getOperands()) 121 if (!loop.isDefinedOutsideOfLoop(operand)) 122 return WalkResult::advance(); 123 124 // Only hoist transfer_read / transfer_write pairs for now. 125 if (!transferWrite) 126 return WalkResult::advance(); 127 128 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 129 << "\n"); 130 131 // Approximate aliasing by checking that: 132 // 1. indices are the same, 133 // 2. no other operations in the loop access the same memref except 134 // for transfer_read/transfer_write accessing statically disjoint 135 // slices. 136 if (transferRead.indices() != transferWrite.indices() && 137 transferRead.getVectorType() == transferWrite.getVectorType()) 138 return WalkResult::advance(); 139 140 // TODO: may want to memoize this information for performance but it 141 // likely gets invalidated often. 142 DominanceInfo dom(loop); 143 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 144 return WalkResult::advance(); 145 for (auto &use : transferRead.source().getUses()) { 146 if (!dom.properlyDominates(loop, use.getOwner())) 147 continue; 148 if (use.getOwner() == transferRead.getOperation() || 149 use.getOwner() == transferWrite.getOperation()) 150 continue; 151 if (auto transferWriteUse = 152 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 153 if (!isDisjointTransferSet( 154 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 155 cast<VectorTransferOpInterface>( 156 transferWriteUse.getOperation()))) 157 return WalkResult::advance(); 158 } else if (auto transferReadUse = 159 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 160 if (!isDisjointTransferSet( 161 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 162 cast<VectorTransferOpInterface>( 163 transferReadUse.getOperation()))) 164 return WalkResult::advance(); 165 } else { 166 // Unknown use, we cannot prove that it doesn't alias with the 167 // transferRead/transferWrite operations. 168 return WalkResult::advance(); 169 } 170 } 171 172 // Hoist read before. 173 if (failed(loop.moveOutOfLoop({transferRead}))) 174 llvm_unreachable( 175 "Unexpected failure to move transfer read out of loop"); 176 177 // Hoist write after. 178 transferWrite->moveAfter(loop); 179 180 // Rewrite `loop` with new yields by cloning and erase the original loop. 181 OpBuilder b(transferRead); 182 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 183 transferWrite.vector()); 184 185 // Transfer write has been hoisted, need to update the written value to 186 // the value yielded by the newForOp. 187 transferWrite.vector().replaceAllUsesWith( 188 newForOp.getResults().take_back()[0]); 189 190 changed = true; 191 loop.erase(); 192 // Need to interrupt and restart because erasing the loop messes up the 193 // walk. 194 return WalkResult::interrupt(); 195 }); 196 } 197 } 198