1 //===- Hoisting.cpp - Linalg hoisting transformations ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements functions concerned with hoisting invariant operations 10 // in the context of Linalg transformations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 15 #include "mlir/Analysis/SliceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Dialect/SCF/Utils.h" 19 #include "mlir/Dialect/StandardOps/IR/Ops.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/IR/Dominance.h" 22 #include "mlir/IR/Function.h" 23 #include "mlir/Transforms/LoopUtils.h" 24 #include "llvm/ADT/StringRef.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "linalg-hoisting" 28 29 #define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") 30 31 using namespace mlir; 32 using namespace mlir::linalg; 33 34 using llvm::dbgs; 35 36 void mlir::linalg::hoistViewAllocOps(FuncOp func) { 37 bool changed = true; 38 while (changed) { 39 changed = false; 40 func.walk([&changed](Operation *op) { 41 if (!isa<AllocOp, AllocaOp, DeallocOp>(op)) 42 return; 43 44 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *op << "\n"); 45 auto loop = dyn_cast<scf::ForOp>(op->getParentOp()); 46 LLVM_DEBUG(DBGS() << "Parent op: " << *op->getParentOp() << "\n"); 47 48 // Only hoist out of immediately enclosing scf::ForOp. 49 if (!loop) 50 return; 51 52 // If any operand is defined inside the loop don't hoist. 53 if (llvm::any_of(op->getOperands(), [&](Value v) { 54 return !loop.isDefinedOutsideOfLoop(v); 55 })) 56 return; 57 58 LLVM_DEBUG(DBGS() << "All operands defined outside \n"); 59 60 // If alloc has other uses than ViewLikeOp and DeallocOp don't hoist. 61 Value v; 62 if (op->getNumResults() > 0) { 63 assert(op->getNumResults() == 1 && "Unexpected multi-result alloc"); 64 v = op->getResult(0); 65 } 66 if (v && !llvm::all_of(v.getUses(), [&](OpOperand &operand) { 67 return isa<ViewLikeOpInterface, DeallocOp>(operand.getOwner()); 68 })) { 69 LLVM_DEBUG(DBGS() << "Found non view-like or dealloc use: bail\n"); 70 return; 71 } 72 73 // Move AllocOp before the loop. 74 if (isa<AllocOp, AllocaOp>(op)) 75 loop.moveOutOfLoop({op}); 76 else // Move DeallocOp outside of the loop. 77 op->moveAfter(loop); 78 changed = true; 79 }); 80 } 81 } 82 83 /// Return true if we can prove that the transfer operations access dijoint 84 /// memory. 85 static bool isDisjoint(VectorTransferOpInterface transferA, 86 VectorTransferOpInterface transferB) { 87 if (transferA.memref() != transferB.memref()) 88 return false; 89 // For simplicity only look at transfer of same type. 90 if (transferA.getVectorType() != transferB.getVectorType()) 91 return false; 92 unsigned rankOffset = transferA.getLeadingMemRefRank(); 93 for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) { 94 auto indexA = transferA.indices()[i].getDefiningOp<ConstantOp>(); 95 auto indexB = transferB.indices()[i].getDefiningOp<ConstantOp>(); 96 // If any of the indices are dynamic we cannot prove anything. 97 if (!indexA || !indexB) 98 continue; 99 100 if (i < rankOffset) { 101 // For dimension used as index if we can prove that index are different we 102 // know we are accessing disjoint slices. 103 if (indexA.getValue().cast<IntegerAttr>().getInt() != 104 indexB.getValue().cast<IntegerAttr>().getInt()) 105 return true; 106 } else { 107 // For this dimension, we slice a part of the memref we need to make sure 108 // the intervals accessed don't overlap. 109 int64_t distance = 110 std::abs(indexA.getValue().cast<IntegerAttr>().getInt() - 111 indexB.getValue().cast<IntegerAttr>().getInt()); 112 if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) 113 return true; 114 } 115 } 116 return false; 117 } 118 119 void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { 120 bool changed = true; 121 while (changed) { 122 changed = false; 123 124 func.walk([&](vector::TransferReadOp transferRead) { 125 LLVM_DEBUG(DBGS() << "Candidate for hoisting: " 126 << *transferRead.getOperation() << "\n"); 127 auto loop = dyn_cast<scf::ForOp>(transferRead.getParentOp()); 128 LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead.getParentOp() 129 << "\n"); 130 if (!loop) 131 return WalkResult::advance(); 132 133 if (failed(moveLoopInvariantCode( 134 cast<LoopLikeOpInterface>(loop.getOperation())))) 135 llvm_unreachable( 136 "Unexpected failure to move invariant code out of loop"); 137 138 LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() 139 << "\n"); 140 141 llvm::SetVector<Operation *> forwardSlice; 142 getForwardSlice(transferRead, &forwardSlice); 143 144 // Look for the last TransferWriteOp in the forwardSlice of 145 // `transferRead` that operates on the same memref. 146 vector::TransferWriteOp transferWrite; 147 for (auto *sliceOp : llvm::reverse(forwardSlice)) { 148 auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp); 149 if (!candidateWrite || candidateWrite.memref() != transferRead.memref()) 150 continue; 151 transferWrite = candidateWrite; 152 } 153 154 // All operands of the TransferRead must be defined outside of the loop. 155 for (auto operand : transferRead.getOperands()) 156 if (!loop.isDefinedOutsideOfLoop(operand)) 157 return WalkResult::advance(); 158 159 // Only hoist transfer_read / transfer_write pairs for now. 160 if (!transferWrite) 161 return WalkResult::advance(); 162 163 LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation() 164 << "\n"); 165 166 // Approximate aliasing by checking that: 167 // 1. indices are the same, 168 // 2. no other operations in the loop access the same memref except 169 // for transfer_read/transfer_write accessing statically disjoint 170 // slices. 171 if (transferRead.indices() != transferWrite.indices() && 172 transferRead.getVectorType() == transferWrite.getVectorType()) 173 return WalkResult::advance(); 174 175 // TODO: may want to memoize this information for performance but it 176 // likely gets invalidated often. 177 DominanceInfo dom(loop); 178 if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) 179 return WalkResult::advance(); 180 for (auto &use : transferRead.memref().getUses()) { 181 if (!dom.properlyDominates(loop, use.getOwner())) 182 continue; 183 if (use.getOwner() == transferRead.getOperation() || 184 use.getOwner() == transferWrite.getOperation()) 185 continue; 186 if (auto transferWriteUse = 187 dyn_cast<vector::TransferWriteOp>(use.getOwner())) { 188 if (!isDisjoint( 189 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 190 cast<VectorTransferOpInterface>( 191 transferWriteUse.getOperation()))) 192 return WalkResult::advance(); 193 } else if (auto transferReadUse = 194 dyn_cast<vector::TransferReadOp>(use.getOwner())) { 195 if (!isDisjoint( 196 cast<VectorTransferOpInterface>(transferWrite.getOperation()), 197 cast<VectorTransferOpInterface>( 198 transferReadUse.getOperation()))) 199 return WalkResult::advance(); 200 } else { 201 // Unknown use, we cannot prove that it doesn't alias with the 202 // transferRead/transferWrite operations. 203 return WalkResult::advance(); 204 } 205 } 206 207 // Hoist read before. 208 if (failed(loop.moveOutOfLoop({transferRead}))) 209 llvm_unreachable( 210 "Unexpected failure to move transfer read out of loop"); 211 212 // Hoist write after. 213 transferWrite.getOperation()->moveAfter(loop); 214 215 // Rewrite `loop` with new yields by cloning and erase the original loop. 216 OpBuilder b(transferRead); 217 auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(), 218 transferWrite.vector()); 219 220 // Transfer write has been hoisted, need to update the written value to 221 // the value yielded by the newForOp. 222 transferWrite.vector().replaceAllUsesWith( 223 newForOp.getResults().take_back()[0]); 224 225 changed = true; 226 loop.erase(); 227 // Need to interrupt and restart because erasing the loop messes up the 228 // walk. 229 return WalkResult::interrupt(); 230 }); 231 } 232 } 233