1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements functions concerned with hoisting invariant operations 10 // in the context of Linalg transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/SCF/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/IR/Dominance.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/Transforms/LoopUtils.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-hoisting" 28 29 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 30 31 using namespace mlir; 32 using namespace mlir::linalg; 33 34 using llvm::dbgs; 35 36 void mlir::linalg::hoistViewAllocOps(FuncOp func) { 37 bool changed = true; 38 while (changed) { 39 changed = false; 40 func.walk([&changed](Operation *op) { 41 if (!isa<AllocOp, AllocaOp, DeallocOp>(op)) 42 return; 43 44 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n"); 45 auto loop = dyn_cast<scf::ForOp>(op->getParentOp()); 46 LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n"); 47 48 // Only hoist out of immediately enclosing scf::ForOp. 49 if (!loop) 50 return; 51 52 // If any operand is defined inside the loop don't hoist. 53 if (llvm::any_of(op->getOperands(), [&](Value v) { 54 return !loop.isDefinedOutsideOfLoop(v); 55 })) 56 return; 57 58 LLVM_DEBUG(DBGS() << "All operands defined outside \n"); 59 60 // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist. 61 Value v; 62 if (op->getNumResults() > 0) { 63 assert(op->getNumResults() == 1 && "Unexpected multi-result alloc"); 64 v = op->getResult(0); 65 } 66 if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) { 67 return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner()); 68 })) { 69 LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n"); 70 return; 71 } 72 73 // Move AllocOp before the loop. 74 if (isa<AllocOp, AllocaOp>(op)) 75 loop.moveOutOfLoop({op}); 76 else // Move DeallocOp outside of the loop. 77 op->moveAfter(loop); 78 changed = true; 79 }); 80 } 81 } 82 83 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 84 bool changed = true; 85 while (changed) { 86 changed = false; 87 88 func.walk([&](vector::TransferReadOp transferRead) { 89 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 90 << *transferRead.getOperation() << "\n"); 91 auto loop = dyn_cast<scf::ForOp>(transferRead.getParentOp()); 92 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead.getParentOp() 93 << "\n"); 94 if (!loop) 95 return WalkResult::advance(); 96 97 if (failed(moveLoopInvariantCode( 98 cast<LoopLikeOpInterface>(loop.getOperation())))) 99 llvm_unreachable( 100 "Unexpected failure to move invariant code out of loop"); 101 102 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 103 << "\n"); 104 105 llvm::SetVector<Operation *> forwardSlice; 106 getForwardSlice(transferRead, &forwardSlice); 107 108 // Look for the last TransferWriteOp in the forwardSlice of 109 // `transferRead` that operates on the same memref. 110 vector::TransferWriteOp transferWrite; 111 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 112 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 113 if (!candidateWrite || candidateWrite.memref() != transferRead.memref()) 114 continue; 115 transferWrite = candidateWrite; 116 } 117 118 // All operands of the TransferRead must be defined outside of the loop. 119 for (auto operand : transferRead.getOperands()) 120 if (!loop.isDefinedOutsideOfLoop(operand)) 121 return WalkResult::advance(); 122 123 // Only hoist transfer_read / transfer_write pairs for now. 124 if (!transferWrite) 125 return WalkResult::advance(); 126 127 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 128 << "\n"); 129 130 // Approximate aliasing by checking that: 131 // 1. indices are the same, 132 // 2. no other use either dominates the transfer_read or is dominated 133 // by the transfer_write (i.e. aliasing between the write and the read 134 // across the loop). 135 if (transferRead.indices() != transferWrite.indices()) 136 return WalkResult::advance(); 137 138 // TODO: may want to memoize this information for performance but it 139 // likely gets invalidated often. 140 DominanceInfo dom(loop); 141 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 142 return WalkResult::advance(); 143 for (auto &use : transferRead.memref().getUses()) 144 if (dom.properlyDominates(use.getOwner(), 145 transferRead.getOperation()) || 146 dom.properlyDominates(transferWrite, use.getOwner())) 147 return WalkResult::advance(); 148 149 // Hoist read before. 150 if (failed(loop.moveOutOfLoop({transferRead}))) 151 llvm_unreachable( 152 "Unexpected failure to move transfer read out of loop"); 153 154 // Hoist write after. 155 transferWrite.getOperation()->moveAfter(loop); 156 157 // Rewrite `loop` with new yields by cloning and erase the original loop. 158 OpBuilder b(transferRead); 159 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 160 transferWrite.vector()); 161 162 // Transfer write has been hoisted, need to update the written value to 163 // the value yielded by the newForOp. 164 transferWrite.vector().replaceAllUsesWith( 165 newForOp.getResults().take_back()[0]); 166 167 changed = true; 168 loop.erase(); 169 // Need to interrupt and restart because erasing the loop messes up the 170 // walk. 171 return WalkResult::interrupt(); 172 }); 173 } 174 } 175