//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::shape;

namespace mlir {
namespace shape {
namespace {

/// Bufferization of shape.assuming.
struct AssumingOpInterface
    : public BufferizableOpInterface::ExternalModel<AssumingOpInterface,
                                                    shape::AssumingOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // AssumingOps do not have tensor OpOperands. The yielded value can be any
    // SSA value that is in scope. To allow for use-def chain traversal through
    // AssumingOps in the analysis, the corresponding yield value is considered
    // to be aliasing with the result.
    auto assumingOp = cast<shape::AssumingOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    // TODO: Support multiple blocks.
    assert(assumingOp.getDoRegion().getBlocks().size() == 1 &&
           "expected exactly 1 block");
    auto yieldOp = dyn_cast<shape::AssumingYieldOp>(
        assumingOp.getDoRegion().front().getTerminator());
    assert(yieldOp && "expected shape.assuming_yield terminator");
    return {&yieldOp->getOpOperand(resultNum)};
  }

  // TODO: For better bufferization results, this could return `true` only if
  // there is a memory write in the region.
  bool isMemoryWrite(Operation *op, OpResult opResult,
                     const AnalysisState &state) const {
    // Similar to scf.if, results of this op are always considered memory
    // writes in the analysis. This is a useful pattern for all ops that have
    // tensor OpResults but no tensor OpOperands. By default, `isMemoryWrite`
    // is implemented in terms of `bufferizesToMemoryWrite`, which does not
    // work on ops without OpOperands.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    auto assumingOp = cast<shape::AssumingOp>(op);

    // Compute new result types.
    SmallVector<Type> newResultTypes;
    for (Type type : assumingOp->getResultTypes()) {
      if (auto tensorType = type.dyn_cast<TensorType>()) {
        newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
      } else {
        newResultTypes.push_back(type);
      }
    }

    // Create new op and move over region.
    auto newOp = rewriter.create<shape::AssumingOp>(
        op->getLoc(), newResultTypes, assumingOp.getWitness());
    newOp.getDoRegion().takeBody(assumingOp.getRegion());

    // Update terminator.
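    // Illustrative sketch of this step (the value names and exact memref
    // layout are assumptions; the actual memref type depends on the
    // bufferization options): a yield such as
    //   shape.assuming_yield %t : tensor<?xf32>
    // is rewritten below to route each tensor operand through a cast to the
    // corresponding new memref result type, roughly:
    //   %m = bufferization.to_memref %t : memref<?xf32>
    //   shape.assuming_yield %m : memref<?xf32>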
    assert(newOp.getDoRegion().getBlocks().size() == 1 &&
           "only 1 block supported");
    Block *newBlock = &newOp.getDoRegion().front();
    auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator());
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldValues;
    for (const auto &it : llvm::enumerate(yieldOp.operands())) {
      Value val = it.value();
      if (val.getType().isa<TensorType>()) {
        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
            yieldOp.getLoc(), newResultTypes[it.index()], val));
      } else {
        newYieldValues.push_back(val);
      }
    }
    rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp,
                                                        newYieldValues);

    // Update all uses of the old op.
    rewriter.setInsertionPointAfter(newOp);
    SmallVector<Value> newResults;
    for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) {
      if (it.value().isa<TensorType>()) {
        newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
            assumingOp.getLoc(), newOp->getResult(it.index())));
      } else {
        newResults.push_back(newOp->getResult(it.index()));
      }
    }

    // Replace old op.
    rewriter.replaceOp(assumingOp, newResults);

    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }

  bool isAllocationHoistingBarrier(Operation *op) const {
    // Allocations should not be hoisted out of AssumingOps.
    return true;
  }
};

/// Bufferization of shape.assuming_yield. Bufferized as part of its enclosing
/// op, so this is for analysis only.
struct AssumingYieldOpInterface
    : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface,
                                                    shape::AssumingYieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    assert(isa<shape::AssumingOp>(op->getParentOp()) &&
           "expected that parent is an AssumingOp");
    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize in-place. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          BufferizationState &state) const {
    // Op is bufferized as part of AssumingOp.
    return failure();
  }
};

} // namespace
} // namespace shape
} // namespace mlir

void mlir::shape::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) {
    shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx);
    shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx);
  });
}
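// A minimal registration sketch (the setup code below is hypothetical client
// code, not part of this file): the external models must be attached via the
// DialectRegistry before shape ops are bufferized, typically when a tool
// configures its MLIRContext.
//
//   DialectRegistry registry;
//   registry.insert<shape::ShapeDialect,
//                   bufferization::BufferizationDialect>();
//   mlir::shape::registerBufferizableOpInterfaceExternalModels(registry);
//   MLIRContext context(registry);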