1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements functions concerned with hoisting invariant operations 10 // in the context of Linalg transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/SCF/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/IR/Dominance.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/Transforms/LoopUtils.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-hoisting" 28 29 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 30 31 using namespace mlir; 32 using namespace mlir::linalg; 33 34 using llvm::dbgs; 35 36 void mlir::linalg::hoistViewAllocOps(FuncOp func) { 37 bool changed = true; 38 while (changed) { 39 changed = false; 40 func.walk([&changed](Operation *op) { 41 if (!isa<AllocOp>(op) && !isa<AllocaOp>(op) && !isa<DeallocOp>(op)) 42 return; 43 44 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n"); 45 auto loop = dyn_cast<scf::ForOp>(op->getParentOp()); 46 LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n"); 47 48 // Only hoist out of immediately enclosing scf::ForOp. 49 if (!loop) 50 return; 51 52 // If any operand is defined inside the loop don't hoist. 53 if (llvm::any_of(op->getOperands(), [&](Value v) { 54 return !loop.isDefinedOutsideOfLoop(v); 55 })) 56 return; 57 58 LLVM_DEBUG(DBGS() << "All operands defined outside \n"); 59 60 // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist. 61 Value v; 62 if (op->getNumResults() > 0) { 63 assert(op->getNumResults() == 1 && "Unexpected multi-result alloc"); 64 v = op->getResult(0); 65 } 66 if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) { 67 return isa<ViewLikeOpInterface>(operand.getOwner()) || 68 isa<DeallocOp>(operand.getOwner()); 69 })) { 70 LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n"); 71 return; 72 } 73 74 // Move AllocOp before the loop. 75 if (isa<AllocOp>(op) || isa<AllocaOp>(op)) 76 loop.moveOutOfLoop({op}); 77 else // Move DeallocOp outside of the loop. 78 op->moveAfter(loop); 79 changed = true; 80 }); 81 } 82 } 83 84 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 85 bool changed = true; 86 while (changed) { 87 changed = false; 88 89 func.walk([&](vector::TransferReadOp transferRead) { 90 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 91 << *transferRead.getOperation() << "\n"); 92 auto loop = dyn_cast<scf::ForOp>(transferRead.getParentOp()); 93 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead.getParentOp() 94 << "\n"); 95 if (!loop) 96 return WalkResult::advance(); 97 98 if (failed(moveLoopInvariantCode( 99 cast<LoopLikeOpInterface>(loop.getOperation())))) 100 llvm_unreachable( 101 "Unexpected failure to move invariant code out of loop"); 102 103 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 104 << "\n"); 105 106 llvm::SetVector<Operation *> forwardSlice; 107 getForwardSlice(transferRead, &forwardSlice); 108 109 // Look for the last TransferWriteOp in the forwardSlice of 110 // `transferRead` that operates on the same memref. 111 vector::TransferWriteOp transferWrite; 112 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 113 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 114 if (!candidateWrite || candidateWrite.memref() != transferRead.memref()) 115 continue; 116 transferWrite = candidateWrite; 117 } 118 119 // All operands of the TransferRead must be defined outside of the loop. 120 for (auto operand : transferRead.getOperands()) 121 if (!loop.isDefinedOutsideOfLoop(operand)) 122 return WalkResult::advance(); 123 124 // Only hoist transfer_read / transfer_write pairs for now. 125 if (!transferWrite) 126 return WalkResult::advance(); 127 128 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 129 << "\n"); 130 131 // Approximate aliasing by checking that: 132 // 1. indices are the same, 133 // 2. no other use either dominates the transfer_read or is dominated 134 // by the transfer_write (i.e. aliasing between the write and the read 135 // across the loop). 136 if (transferRead.indices() != transferWrite.indices()) 137 return WalkResult::advance(); 138 139 // TODO: may want to memoize this information for performance but it 140 // likely gets invalidated often. 141 DominanceInfo dom(loop); 142 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 143 return WalkResult::advance(); 144 for (auto &use : transferRead.memref().getUses()) 145 if (dom.properlyDominates(use.getOwner(), 146 transferRead.getOperation()) || 147 dom.properlyDominates(transferWrite, use.getOwner())) 148 return WalkResult::advance(); 149 150 // Hoist read before. 151 if (failed(loop.moveOutOfLoop({transferRead}))) 152 llvm_unreachable( 153 "Unexpected failure to move transfer read out of loop"); 154 155 // Hoist write after. 156 transferWrite.getOperation()->moveAfter(loop); 157 158 // Rewrite `loop` with new yields by cloning and erase the original loop. 159 OpBuilder b(transferRead); 160 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 161 transferWrite.vector()); 162 163 // Transfer write has been hoisted, need to update the written value to 164 // the value yielded by the newForOp. 165 transferWrite.vector().replaceAllUsesWith( 166 newForOp.getResults().take_back()[0]); 167 168 changed = true; 169 loop.erase(); 170 // Need to interrupt and restart because erasing the loop messes up the 171 // walk. 172 return WalkResult::interrupt(); 173 }); 174 } 175 } 176