1 //===- AllocTensorElimination.cpp - alloc_tensor op elimination -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Bufferization/Transforms/AllocTensorElimination.h" 14 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" 15 #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 16 #include "mlir/Dialect/Tensor/IR/Tensor.h" 17 #include "mlir/IR/Dominance.h" 18 #include "mlir/Pass/Pass.h" 19 20 using namespace mlir; 21 using namespace mlir::bufferization; 22 23 /// Return true if all `neededValues` are in scope at the given 24 /// `insertionPoint`. 25 static bool 26 neededValuesDominateInsertionPoint(const DominanceInfo &domInfo, 27 Operation *insertionPoint, 28 const SmallVector<Value> &neededValues) { 29 for (Value val : neededValues) { 30 if (auto bbArg = val.dyn_cast<BlockArgument>()) { 31 Block *owner = bbArg.getOwner(); 32 if (!owner->findAncestorOpInBlock(*insertionPoint)) 33 return false; 34 } else { 35 auto opResult = val.cast<OpResult>(); 36 if (!domInfo.dominates(opResult.getOwner(), insertionPoint)) 37 return false; 38 } 39 } 40 return true; 41 } 42 43 /// Return true if the given `insertionPoint` dominates all uses of 44 /// `allocTensorOp`. 45 static bool insertionPointDominatesUses(const DominanceInfo &domInfo, 46 Operation *insertionPoint, 47 Operation *allocTensorOp) { 48 for (Operation *user : allocTensorOp->getUsers()) 49 if (!domInfo.dominates(insertionPoint, user)) 50 return false; 51 return true; 52 } 53 54 /// Find a valid insertion point for a replacement of `allocTensorOp`, assuming 55 /// that the replacement may use any value from `neededValues`. 56 static Operation * 57 findValidInsertionPoint(Operation *allocTensorOp, 58 const SmallVector<Value> &neededValues) { 59 DominanceInfo domInfo; 60 61 // Gather all possible insertion points: the location of `allocTensorOp` and 62 // right after the definition of each value in `neededValues`. 63 SmallVector<Operation *> insertionPointCandidates; 64 insertionPointCandidates.push_back(allocTensorOp); 65 for (Value val : neededValues) { 66 // Note: The anchor op is using all of `neededValues`, so: 67 // * in case of a block argument: There must be at least one op in the block 68 // (the anchor op or one of its parents). 69 // * in case of an OpResult: There must be at least one op right after the 70 // defining op (the anchor op or one of its 71 // parents). 72 if (auto bbArg = val.dyn_cast<BlockArgument>()) { 73 insertionPointCandidates.push_back( 74 &bbArg.getOwner()->getOperations().front()); 75 } else { 76 insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode()); 77 } 78 } 79 80 // Select first matching insertion point. 81 for (Operation *insertionPoint : insertionPointCandidates) { 82 // Check if all needed values are in scope. 83 if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint, 84 neededValues)) 85 continue; 86 // Check if the insertion point is before all uses. 87 if (!insertionPointDominatesUses(domInfo, insertionPoint, allocTensorOp)) 88 continue; 89 return insertionPoint; 90 } 91 92 // No suitable insertion point was found. 93 return nullptr; 94 } 95 96 /// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp is replaced 97 /// with the result of `rewriteFunc` if it is anchored on a matching 98 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def 99 /// chain, starting from the OpOperand and always following the aliasing 100 /// OpOperand, that eventually ends at a single AllocTensorOp. 101 LogicalResult mlir::bufferization::eliminateAllocTensors( 102 RewriterBase &rewriter, Operation *op, AnalysisState &state, 103 AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc) { 104 OpBuilder::InsertionGuard g(rewriter); 105 106 WalkResult status = op->walk([&](Operation *op) { 107 for (OpOperand &operand : op->getOpOperands()) { 108 // Skip operands that do not bufferize inplace. 109 if (!state.isInPlace(operand)) 110 continue; 111 // All values that are needed to create the replacement op. 112 SmallVector<Value> neededValues; 113 // Is this a matching OpOperand? 114 if (!anchorMatchFunc(operand, neededValues)) 115 continue; 116 SetVector<Value> maybeAllocTensor = 117 state.findValueInReverseUseDefChain(operand.get(), [&](Value val) { 118 // Continue traversal until this function returns true. 119 OpResult opResult = val.dyn_cast<OpResult>(); 120 if (!opResult) 121 return true; 122 SmallVector<OpOperand *> opOperands = 123 state.getAliasingOpOperand(opResult); 124 if (!llvm::all_of(opOperands, [&](OpOperand *operand) { 125 return state.isInPlace(*operand); 126 })) 127 return true; 128 // Only equivalent tensors are supported at the moment. 129 // TODO: Support cases such as extract_slice(alloc_tensor) 130 return !llvm::all_of(opOperands, [&](OpOperand *operand) { 131 return state.areEquivalentBufferizedValues(operand->get(), 132 opResult); 133 }); 134 }); 135 136 // Replace only if the reverse use-def chain ends at exactly one 137 // AllocTensorOp. 138 if (maybeAllocTensor.size() != 1 || 139 !maybeAllocTensor.front().getDefiningOp<AllocTensorOp>()) 140 return WalkResult::skip(); 141 Value allocTensor = maybeAllocTensor.front(); 142 143 // Find a suitable insertion point. 144 Operation *insertionPoint = 145 findValidInsertionPoint(allocTensor.getDefiningOp(), neededValues); 146 if (!insertionPoint) 147 continue; 148 149 // Create a replacement for the AllocTensorOp. 150 rewriter.setInsertionPoint(insertionPoint); 151 Value replacement = rewriteFunc(rewriter, allocTensor.getLoc(), operand); 152 if (!replacement) 153 continue; 154 155 // Replace the AllocTensorOp. 156 rewriter.replaceOp(allocTensor.getDefiningOp(), replacement); 157 } 158 159 // Advance to the next operation. 160 return WalkResult::advance(); 161 }); 162 163 return failure(status.wasInterrupted()); 164 } 165 166 /// Try to eliminate AllocTensorOps inside `op`. An AllocTensorOp can be 167 /// eliminated if it is eventually inserted into another tensor (and some other 168 /// conditions are met). 169 /// 170 /// E.g.: 171 /// %0 = linalg.alloc_tensor 172 /// %1 = linalg.fill(%cst, %0) {inplace = [true]} 173 /// %2 = tensor.insert_slice %1 into %t[10][20][1] 174 /// 175 /// AllocTensorOp elimination will try to fill %t inplace instead of filling a 176 /// new allocation %0 and inserting it into %t. This is done by replacing the 177 /// AllocTensorOp with: 178 /// 179 /// %0 = tensor.extract_slice %t[10][20][1] 180 /// 181 /// The analysis looks for matching ExtractSliceOp/InsertSliceOp pairs and lets 182 /// those bufferize inplace in the absence of other conflicts. 183 /// 184 /// Starting from an InsertSliceOp, an AllocTensorOp at the end of the insert 185 /// source's reverse use-def chain is eliminated if: 186 /// * On the reverse use-def chain path from the InsertSliceOp to the 187 /// AllocTensorOp, all ops were decided to bufferize inplace and the buffer 188 /// relation is "equivalent" (TODO: can be relaxed if needed). 189 /// * The reverse use-def chain has exactly one end, which is the AllocTensorOp. 190 LogicalResult 191 mlir::bufferization::insertSliceAnchoredAllocTensorEliminationStep( 192 RewriterBase &rewriter, Operation *op, AnalysisState &state) { 193 return eliminateAllocTensors( 194 rewriter, op, state, 195 /*anchorMatchFunc=*/ 196 [&](OpOperand &operand, SmallVector<Value> &neededValues) { 197 auto insertSliceOp = 198 dyn_cast<tensor::InsertSliceOp>(operand.getOwner()); 199 if (!insertSliceOp) 200 return false; 201 if (&operand != &insertSliceOp->getOpOperand(0) /*source*/) 202 return false; 203 204 // Collect all values that are needed to construct the replacement op. 205 neededValues.append(insertSliceOp.getOffsets().begin(), 206 insertSliceOp.getOffsets().end()); 207 neededValues.append(insertSliceOp.getSizes().begin(), 208 insertSliceOp.getSizes().end()); 209 neededValues.append(insertSliceOp.getStrides().begin(), 210 insertSliceOp.getStrides().end()); 211 neededValues.push_back(insertSliceOp.getDest()); 212 213 return true; 214 }, 215 /*rewriteFunc=*/ 216 [](OpBuilder &b, Location loc, OpOperand &operand) { 217 auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner()); 218 auto extractOp = b.create<tensor::ExtractSliceOp>( 219 loc, insertOp.getSourceType(), insertOp.getDest(), 220 insertOp.getMixedOffsets(), insertOp.getMixedSizes(), 221 insertOp.getMixedStrides()); 222 return extractOp.getResult(); 223 }); 224 } 225 226 namespace { 227 struct AllocTensorElimination 228 : public AllocTensorEliminationBase<AllocTensorElimination> { 229 AllocTensorElimination() = default; 230 231 void runOnOperation() override; 232 233 void getDependentDialects(DialectRegistry ®istry) const override { 234 registry 235 .insert<bufferization::BufferizationDialect, tensor::TensorDialect>(); 236 } 237 }; 238 } // namespace 239 240 void AllocTensorElimination::runOnOperation() { 241 Operation *op = getOperation(); 242 OneShotBufferizationOptions options; 243 OneShotAnalysisState state(op, options); 244 if (failed(analyzeOp(op, state))) { 245 signalPassFailure(); 246 return; 247 } 248 249 IRRewriter rewriter(op->getContext()); 250 if (failed(bufferization::insertSliceAnchoredAllocTensorEliminationStep( 251 rewriter, op, state))) 252 signalPassFailure(); 253 } 254 255 std::unique_ptr<Pass> mlir::bufferization::createAllocTensorEliminationPass() { 256 return std::make_unique<AllocTensorElimination>(); 257 } 258