1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 17 using namespace mlir; 18 using namespace mlir::bufferization; 19 20 namespace { 21 /// Bufferization of arith.constant. Replace with memref.get_global. 22 struct ConstantOpInterface 23 : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, 24 arith::ConstantOp> { 25 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 26 const BufferizationOptions &options) const { 27 auto constantOp = cast<arith::ConstantOp>(op); 28 29 // TODO: Implement memory space for this op. E.g., by adding a memory_space 30 // attribute to ConstantOp. 31 if (options.defaultMemorySpace != static_cast<unsigned>(0)) 32 return op->emitError("memory space not implemented yet"); 33 34 // Only ranked tensors are supported. 35 if (!constantOp.getType().isa<RankedTensorType>()) 36 return failure(); 37 38 // Only constants inside a module are supported. 39 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 40 if (!moduleOp) 41 return failure(); 42 43 // Create global memory segment and replace tensor with memref pointing to 44 // that memory segment. 45 FailureOr<memref::GlobalOp> globalOp = 46 getGlobalFor(constantOp, options.bufferAlignment); 47 if (failed(globalOp)) 48 return failure(); 49 memref::GlobalOp globalMemref = *globalOp; 50 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( 51 rewriter, op, globalMemref.getType(), globalMemref.getName()); 52 53 return success(); 54 } 55 56 bool isWritable(Operation *op, Value value, 57 const AnalysisState &state) const { 58 // Memory locations returned by memref::GetGlobalOp may not be written to. 59 assert(value.isa<OpResult>()); 60 return false; 61 } 62 }; 63 64 struct IndexCastOpInterface 65 : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface, 66 arith::IndexCastOp> { 67 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 68 const AnalysisState &state) const { 69 return false; 70 } 71 72 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 73 const AnalysisState &state) const { 74 return false; 75 } 76 77 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 78 const AnalysisState &state) const { 79 return {op->getResult(0)}; 80 } 81 82 BufferRelation bufferRelation(Operation *op, OpResult opResult, 83 const AnalysisState &state) const { 84 return BufferRelation::Equivalent; 85 } 86 87 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 88 const BufferizationOptions &options) const { 89 auto castOp = cast<arith::IndexCastOp>(op); 90 auto resultTensorType = castOp.getType().cast<TensorType>(); 91 92 FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options); 93 if (failed(source)) 94 return failure(); 95 auto sourceType = source->getType().cast<BaseMemRefType>(); 96 97 // Result type should have same layout and address space as the source type. 98 BaseMemRefType resultType; 99 if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) { 100 resultType = MemRefType::get( 101 rankedMemRefType.getShape(), resultTensorType.getElementType(), 102 rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); 103 } else { 104 auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>(); 105 resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), 106 unrankedMemrefType.getMemorySpace()); 107 } 108 109 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType, 110 *source); 111 return success(); 112 } 113 }; 114 115 /// Bufferization of arith.select. Just replace the operands. 116 struct SelectOpInterface 117 : public BufferizableOpInterface::ExternalModel<SelectOpInterface, 118 arith::SelectOp> { 119 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 120 const AnalysisState &state) const { 121 return false; 122 } 123 124 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 125 const AnalysisState &state) const { 126 return false; 127 } 128 129 SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 130 const AnalysisState &state) const { 131 return {op->getOpResult(0) /*result*/}; 132 } 133 134 SmallVector<OpOperand *> 135 getAliasingOpOperand(Operation *op, OpResult opResult, 136 const AnalysisState &state) const { 137 return {&op->getOpOperand(1) /*true_value*/, 138 &op->getOpOperand(2) /*false_value*/}; 139 } 140 141 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 142 const BufferizationOptions &options) const { 143 auto selectOp = cast<arith::SelectOp>(op); 144 Location loc = selectOp.getLoc(); 145 146 // TODO: It would be more efficient to copy the result of the `select` op 147 // instead of its OpOperands. In the worst case, 2 copies are inserted at 148 // the moment (one for each tensor). When copying the op result, only one 149 // copy would be needed. 150 FailureOr<Value> maybeTrueBuffer = 151 getBuffer(rewriter, selectOp.getTrueValue(), options); 152 FailureOr<Value> maybeFalseBuffer = 153 getBuffer(rewriter, selectOp.getFalseValue(), options); 154 if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer)) 155 return failure(); 156 Value trueBuffer = *maybeTrueBuffer; 157 Value falseBuffer = *maybeFalseBuffer; 158 BaseMemRefType trueType = trueBuffer.getType().cast<BaseMemRefType>(); 159 BaseMemRefType falseType = falseBuffer.getType().cast<BaseMemRefType>(); 160 if (trueType.getMemorySpaceAsInt() != falseType.getMemorySpaceAsInt()) 161 return op->emitError("inconsistent memory space on true/false operands"); 162 163 // The "true" and the "false" operands must have the same type. If the 164 // buffers have different types, they differ only in their layout map. Cast 165 // both of them to the most dynamic MemRef type. 166 if (trueBuffer.getType() != falseBuffer.getType()) { 167 auto trueType = trueBuffer.getType().cast<MemRefType>(); 168 int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; 169 SmallVector<int64_t> dynamicStrides(trueType.getRank(), 170 ShapedType::kDynamicStrideOrOffset); 171 AffineMap stridedLayout = makeStridedLinearLayoutMap( 172 dynamicStrides, dynamicOffset, op->getContext()); 173 auto castedType = 174 MemRefType::get(trueType.getShape(), trueType.getElementType(), 175 stridedLayout, trueType.getMemorySpaceAsInt()); 176 trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer); 177 falseBuffer = 178 rewriter.create<memref::CastOp>(loc, castedType, falseBuffer); 179 } 180 181 replaceOpWithNewBufferizedOp<arith::SelectOp>( 182 rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); 183 return success(); 184 } 185 186 BufferRelation bufferRelation(Operation *op, OpResult opResult, 187 const AnalysisState &state) const { 188 return BufferRelation::None; 189 } 190 }; 191 192 } // namespace 193 194 void mlir::arith::registerBufferizableOpInterfaceExternalModels( 195 DialectRegistry ®istry) { 196 registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) { 197 ConstantOp::attachInterface<ConstantOpInterface>(*ctx); 198 IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx); 199 SelectOp::attachInterface<SelectOpInterface>(*ctx); 200 }); 201 } 202