1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" 10 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 13 #include "mlir/Dialect/Shape/IR/Shape.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/IR/PatternMatch.h" 17 18 using namespace mlir; 19 using namespace mlir::bufferization; 20 using namespace mlir::shape; 21 22 namespace mlir { 23 namespace shape { 24 namespace { 25 26 /// Bufferization of shape.assuming. 27 struct AssumingOpInterface 28 : public BufferizableOpInterface::ExternalModel<AssumingOpInterface, 29 shape::AssumingOp> { 30 SmallVector<OpOperand *> 31 getAliasingOpOperand(Operation *op, OpResult opResult, 32 const AnalysisState &state) const { 33 // AssumingOps do not have tensor OpOperands. The yielded value can be any 34 // SSA value that is in scope. To allow for use-def chain traversal through 35 // AssumingOps in the analysis, the corresponding yield value is considered 36 // to be aliasing with the result. 37 auto assumingOp = cast<shape::AssumingOp>(op); 38 size_t resultNum = std::distance(op->getOpResults().begin(), 39 llvm::find(op->getOpResults(), opResult)); 40 // TODO: Support multiple blocks. 41 assert(assumingOp.getDoRegion().getBlocks().size() == 1 && 42 "expected exactly 1 block"); 43 auto yieldOp = dyn_cast<shape::AssumingYieldOp>( 44 assumingOp.getDoRegion().front().getTerminator()); 45 assert(yieldOp && "expected shape.assuming_yield terminator"); 46 return {&yieldOp->getOpOperand(resultNum)}; 47 } 48 49 // TODO: For better bufferization results, this could return `true` only if 50 // there is a memory write in the region. 51 bool isMemoryWrite(Operation *op, OpResult opResult, 52 const AnalysisState &state) const { 53 // Similar to scf.if, results of this op are always considered memory writes 54 // in the analysis. This is a useful pattern for all ops that have tensor 55 // OpResults but no tensor OpOperands. By default, `isMemoryWrite` is 56 // implemented in terms of `bufferizesToMemoryWrite`, which does not work on 57 // ops without OpOperands. 58 return true; 59 } 60 61 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 62 const BufferizationOptions &options) const { 63 auto assumingOp = cast<shape::AssumingOp>(op); 64 assert(assumingOp.getDoRegion().getBlocks().size() == 1 && 65 "only 1 block supported"); 66 auto yieldOp = cast<shape::AssumingYieldOp>( 67 assumingOp.getDoRegion().front().getTerminator()); 68 69 // Create new op and move over region. 70 TypeRange newResultTypes(yieldOp.operands()); 71 auto newOp = rewriter.create<shape::AssumingOp>( 72 op->getLoc(), newResultTypes, assumingOp.getWitness()); 73 newOp.getDoRegion().takeBody(assumingOp.getRegion()); 74 75 // Update all uses of the old op. 76 rewriter.setInsertionPointAfter(newOp); 77 SmallVector<Value> newResults; 78 for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) { 79 if (it.value().isa<TensorType>()) { 80 newResults.push_back(rewriter.create<bufferization::ToTensorOp>( 81 assumingOp.getLoc(), newOp->getResult(it.index()))); 82 } else { 83 newResults.push_back(newOp->getResult(it.index())); 84 } 85 } 86 87 // Replace old op. 88 rewriter.replaceOp(assumingOp, newResults); 89 90 return success(); 91 } 92 93 BufferRelation bufferRelation(Operation *op, OpResult opResult, 94 const AnalysisState &state) const { 95 return BufferRelation::Equivalent; 96 } 97 }; 98 99 /// Bufferization of shape.assuming_yield. Bufferized as part of their enclosing 100 /// ops, so this is for analysis only. 101 struct AssumingYieldOpInterface 102 : public BufferizableOpInterface::ExternalModel<AssumingYieldOpInterface, 103 shape::AssumingYieldOp> { 104 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 105 const AnalysisState &state) const { 106 return true; 107 } 108 109 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 110 const AnalysisState &state) const { 111 return false; 112 } 113 114 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 115 const AnalysisState &state) const { 116 assert(isa<shape::AssumingOp>(op->getParentOp()) && 117 "expected that parent is an AssumingOp"); 118 return {op->getParentOp()->getResult(opOperand.getOperandNumber())}; 119 } 120 121 bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, 122 const AnalysisState &state) const { 123 // Yield operands always bufferize inplace. Otherwise, an alloc + copy 124 // may be generated inside the block. We should not return/yield allocations 125 // when possible. 126 return true; 127 } 128 129 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 130 const BufferizationOptions &options) const { 131 auto yieldOp = cast<shape::AssumingYieldOp>(op); 132 SmallVector<Value> newResults; 133 for (Value value : yieldOp.operands()) { 134 if (value.getType().isa<TensorType>()) { 135 FailureOr<Value> buffer = getBuffer(rewriter, value, options); 136 if (failed(buffer)) 137 return failure(); 138 newResults.push_back(*buffer); 139 } else { 140 newResults.push_back(value); 141 } 142 } 143 replaceOpWithNewBufferizedOp<shape::AssumingYieldOp>(rewriter, op, 144 newResults); 145 return success(); 146 } 147 }; 148 149 } // namespace 150 } // namespace shape 151 } // namespace mlir 152 153 void mlir::shape::registerBufferizableOpInterfaceExternalModels( 154 DialectRegistry ®istry) { 155 registry.addExtension(+[](MLIRContext *ctx, shape::ShapeDialect *dialect) { 156 shape::AssumingOp::attachInterface<AssumingOpInterface>(*ctx); 157 shape::AssumingYieldOp::attachInterface<AssumingYieldOpInterface>(*ctx); 158 }); 159 } 160