1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 17 using namespace mlir::bufferization; 18 19 namespace mlir { 20 namespace arith { 21 namespace { 22 23 /// Bufferization of arith.constant. Replace with memref.get_global. 24 struct ConstantOpInterface 25 : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, 26 arith::ConstantOp> { 27 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 28 const BufferizationState &state) const { 29 auto constantOp = cast<arith::ConstantOp>(op); 30 31 // Only ranked tensors are supported. 32 if (!constantOp.getType().isa<RankedTensorType>()) 33 return failure(); 34 35 // Only constants inside a module are supported. 36 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 37 if (!moduleOp) 38 return failure(); 39 40 // Create global memory segment and replace tensor with memref pointing to 41 // that memory segment. 42 GlobalCreator globalCreator(moduleOp); 43 auto globalMemref = globalCreator.getGlobalFor(constantOp); 44 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( 45 rewriter, op, globalMemref.type(), globalMemref.getName()); 46 47 return success(); 48 } 49 50 bool isWritable(Operation *op, Value value, 51 const BufferizationState &state) const { 52 // Memory locations returned by memref::GetGlobalOp may not be written to. 53 assert(value.isa<OpResult>()); 54 return false; 55 } 56 }; 57 58 struct IndexCastOpInterface 59 : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface, 60 arith::IndexCastOp> { 61 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 62 const BufferizationState &state) const { 63 return false; 64 } 65 66 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 67 const BufferizationState &state) const { 68 return false; 69 } 70 71 OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 72 const BufferizationState &state) const { 73 return op->getResult(0); 74 } 75 76 BufferRelation bufferRelation(Operation *op, OpResult opResult, 77 const BufferizationState &state) const { 78 return BufferRelation::Equivalent; 79 } 80 81 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 82 const BufferizationState &state) const { 83 auto castOp = cast<arith::IndexCastOp>(op); 84 85 Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); 86 auto sourceType = source.getType().cast<BaseMemRefType>(); 87 88 // Result type should have same layout and address space as the source type. 89 MemRefLayoutAttrInterface layout = {}; 90 if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) 91 layout = rankedMemRefType.getLayout(); 92 Type resultType = 93 getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(), 94 layout, sourceType.getMemorySpace()); 95 96 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source, 97 resultType); 98 return success(); 99 } 100 }; 101 102 } // namespace 103 } // namespace arith 104 } // namespace mlir 105 106 void mlir::arith::registerBufferizableOpInterfaceExternalModels( 107 DialectRegistry ®istry) { 108 registry.addOpInterface<ConstantOp, ConstantOpInterface>(); 109 registry.addOpInterface<IndexCastOp, IndexCastOpInterface>(); 110 } 111