1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::bufferization; 20 using namespace mlir::shape; 21 22 namespace mlir { 23 namespace shape { 24 namespace { 25 26 /// Bufferization of shape.assuming. 27 struct AssumingOpInterface 28 : public BufferizableOpInterface::ExternalModel<AssumingOpInterface, 29 shape::AssumingOp> { 30 SmallVector<OpOperand *> 31 getAliasingOpOperand(Operation *op, OpResult opResult, 32 const AnalysisState &state) const { 33 // AssumingOps do not have tensor OpOperands. The yielded value can be any 34 // SSA value that is in scope. To allow for use-def chain traversal through 35 // AssumingOps in the analysis, the corresponding yield value is considered 36 // to be aliasing with the result. 37 auto assumingOp = cast<shape::AssumingOp>(op); 38 size_t resultNum = std::distance(op->getOpResults().begin(), 39 llvm::find(op->getOpResults(), opResult)); 40 // TODO: Support multiple blocks. 41 assert(assumingOp.getDoRegion().getBlocks().size() == 1 && 42 "expected exactly 1 block"); 43 auto yieldOp = dyn_cast<shape::AssumingYieldOp>( 44 assumingOp.getDoRegion().front().getTerminator()); 45 assert(yieldOp && "expected shape.assuming_yield terminator"); 46 return {&yieldOp->getOpOperand(resultNum)}; 47 } 48 49 // TODO: For better bufferization results, this could return `true` only if 50 // there is a memory write in the region. 51 bool isMemoryWrite(Operation *op, OpResult opResult, 52 const AnalysisState &state) const { 53 // Similar to scf.if, results of this op are always considered memory writes 54 // in the analysis. This is a useful pattern for all ops that have tensor 55 // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is 56 // implemented in terms of `bufferizesToMemoryWrite`, which does not work on 57 // ops without OpOperands. 58 return true; 59 } 60 61 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 62 BufferizationState &state) const { 63 auto assumingOp = cast<shape::AssumingOp>(op); 64 65 // Compute new result types. 66 SmallVector<Type> newResultTypes; 67 for (Type type : assumingOp->getResultTypes()) { 68 if (auto tensorType = type.dyn_cast<TensorType>()) { 69 // TODO: Infer the result type instead of computing it. 70 newResultTypes.push_back(getMemRefType(tensorType, state.getOptions())); 71 } else { 72 newResultTypes.push_back(type); 73 } 74 } 75 76 // Create new op and move over region. 77 auto newOp = rewriter.create<shape::AssumingOp>( 78 op->getLoc(), newResultTypes, assumingOp.getWitness()); 79 newOp.getDoRegion().takeBody(assumingOp.getRegion()); 80 81 // Update terminator. 82 assert(newOp.getDoRegion().getBlocks().size() == 1 && 83 "only 1 block supported"); 84 Block *newBlock = &newOp.getDoRegion().front(); 85 auto yieldOp = cast<shape::AssumingYieldOp>(newBlock->getTerminator()); 86 rewriter.setInsertionPoint(yieldOp); 87 SmallVector<Value> newYieldValues; 88 for (const auto &it : llvm::enumerate(yieldOp.operands())) { 89 Value val = it.value(); 90 if (val.getType().isa<TensorType>()) { 91 newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>( 92 yieldOp.getLoc(), newResultTypes[it.index()], val)); 93 } else { 94 newYieldValues.push_back(val); 95 } 96 } 97 rewriter.replaceOpWithNewOp<shape::AssumingYieldOp>(yieldOp, 98 newYieldValues); 99 100 // Update all uses of the old op. 101 rewriter.setInsertionPointAfter(newOp); 102 SmallVector<Value> newResults; 103 for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) { 104 if (it.value().isa<TensorType>()) { 105 newResults.push_back(rewriter.create<bufferization::ToTensorOp>( 106 assumingOp.getLoc(), newOp->getResult(it.index()))); 107 } else { 108 newResults.push_back(newOp->getResult(it.index())); 109 } 110 } 111 112 // Replace old op. 113 rewriter.replaceOp(assumingOp, newResults); 114 115 return success(); 116 } 117 118 BufferRelation bufferRelation(Operation *op, OpResult opResult, 119 const AnalysisState &state) const { 120 return BufferRelation::Equivalent; 121 } 122 }; 123 124 /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing 125 /// ops, so this is for analysis only. 126 struct AssumingYieldOpInterface 127 : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface, 128 shape::AssumingYieldOp> { 129 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 130 const AnalysisState &state) const { 131 return true; 132 } 133 134 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 135 const AnalysisState &state) const { 136 return false; 137 } 138 139 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 140 const AnalysisState &state) const { 141 assert(isa<shape::AssumingOp>(op->getParentOp()) && 142 "expected that parent is an AssumingOp"); 143 return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; 144 } 145 146 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, 147 const AnalysisState &state) const { 148 // Yield operands always bufferize inplace. Otherwise, an alloc + copy 149 // may be generated inside the block. We should not return/yield allocations 150 // when possible. 151 return true; 152 } 153 154 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 155 BufferizationState &state) const { 156 // Op is bufferized as part of AssumingOp. 157 return failure(); 158 } 159 }; 160 161 } // namespace 162 } // namespace shape 163 } // namespace mlir 164 165 void mlir::shape::registerBufferizableOpInterfaceExternalModels( 166 DialectRegistry ®istry) { 167 registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) { 168 shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx); 169 shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx); 170 }); 171 } 172