1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 17 using namespace mlir; 18 using namespace mlir::bufferization; 19 20 namespace { 21 /// Bufferization of arith.constant. Replace with memref.get_global. 22 struct ConstantOpInterface 23 : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, 24 arith::ConstantOp> { 25 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 26 const BufferizationState &state) const { 27 auto constantOp = cast<arith::ConstantOp>(op); 28 29 // Only ranked tensors are supported. 30 if (!constantOp.getType().isa<RankedTensorType>()) 31 return failure(); 32 33 // Only constants inside a module are supported. 34 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 35 if (!moduleOp) 36 return failure(); 37 38 // Create global memory segment and replace tensor with memref pointing to 39 // that memory segment. 40 FailureOr<memref::GlobalOp> globalOp = 41 getGlobalFor(constantOp, state.getOptions().bufferAlignment); 42 if (failed(globalOp)) 43 return failure(); 44 memref::GlobalOp globalMemref = globalOp.getValue(); 45 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( 46 rewriter, op, globalMemref.type(), globalMemref.getName()); 47 48 return success(); 49 } 50 51 bool isWritable(Operation *op, Value value, 52 const BufferizationState &state) const { 53 // Memory locations returned by memref::GetGlobalOp may not be written to. 54 assert(value.isa<OpResult>()); 55 return false; 56 } 57 }; 58 59 struct IndexCastOpInterface 60 : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface, 61 arith::IndexCastOp> { 62 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 63 const BufferizationState &state) const { 64 return false; 65 } 66 67 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 68 const BufferizationState &state) const { 69 return false; 70 } 71 72 SmallVector<OpResult> 73 getAliasingOpResult(Operation *op, OpOperand &opOperand, 74 const BufferizationState &state) const { 75 return {op->getResult(0)}; 76 } 77 78 BufferRelation bufferRelation(Operation *op, OpResult opResult, 79 const BufferizationState &state) const { 80 return BufferRelation::Equivalent; 81 } 82 83 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 84 const BufferizationState &state) const { 85 auto castOp = cast<arith::IndexCastOp>(op); 86 87 Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); 88 auto sourceType = source.getType().cast<BaseMemRefType>(); 89 90 // Result type should have same layout and address space as the source type. 91 MemRefLayoutAttrInterface layout = {}; 92 if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) 93 layout = rankedMemRefType.getLayout(); 94 Type resultType = 95 getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(), 96 layout, sourceType.getMemorySpace()); 97 98 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType, 99 source); 100 return success(); 101 } 102 }; 103 104 /// Bufferization of arith.select. Just replace the operands. 105 struct SelectOpInterface 106 : public BufferizableOpInterface::ExternalModel<SelectOpInterface, 107 arith::SelectOp> { 108 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 109 const BufferizationState &state) const { 110 return false; 111 } 112 113 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 114 const BufferizationState &state) const { 115 return false; 116 } 117 118 SmallVector<OpResult> 119 getAliasingOpResult(Operation *op, OpOperand &opOperand, 120 const BufferizationState &state) const { 121 return {op->getOpResult(0) /*result*/}; 122 } 123 124 SmallVector<OpOperand *> 125 getAliasingOpOperand(Operation *op, OpResult opResult, 126 const BufferizationState &state) const { 127 return {&op->getOpOperand(1) /*true_value*/, 128 &op->getOpOperand(2) /*false_value*/}; 129 } 130 131 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 132 const BufferizationState &state) const { 133 auto selectOp = cast<arith::SelectOp>(op); 134 135 // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. 136 // TODO: It would be more efficient to copy the result of the `select` op 137 // instead of its OpOperands. In the worst case, 2 copies are inserted at 138 // the moment (one for each tensor). When copying the op result, only one 139 // copy would be needed. 140 Value trueBuffer = 141 *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); 142 Value falseBuffer = 143 *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); 144 replaceOpWithNewBufferizedOp<arith::SelectOp>( 145 rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); 146 return success(); 147 } 148 149 BufferRelation bufferRelation(Operation *op, OpResult opResult, 150 const BufferizationState &state) const { 151 return BufferRelation::None; 152 } 153 }; 154 155 } // namespace 156 157 void mlir::arith::registerBufferizableOpInterfaceExternalModels( 158 DialectRegistry ®istry) { 159 registry.addOpInterface<ConstantOp, ConstantOpInterface>(); 160 registry.addOpInterface<IndexCastOp, IndexCastOpInterface>(); 161 registry.addOpInterface<SelectOp, SelectOpInterface>(); 162 } 163