1 //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 13 #include "mlir/Dialect/MemRef/IR/MemRef.h" 14 #include "mlir/IR/Dialect.h" 15 #include "mlir/IR/Operation.h" 16 17 using namespace mlir::bufferization; 18 19 namespace mlir { 20 namespace arith { 21 namespace { 22 23 /// Bufferization of arith.constant. Replace with memref.get_global. 24 struct ConstantOpInterface 25 : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, 26 arith::ConstantOp> { 27 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 28 const BufferizationState &state) const { 29 auto constantOp = cast<arith::ConstantOp>(op); 30 31 // Only ranked tensors are supported. 32 if (!constantOp.getType().isa<RankedTensorType>()) 33 return failure(); 34 35 // Only constants inside a module are supported. 36 auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 37 if (!moduleOp) 38 return failure(); 39 40 // Create global memory segment and replace tensor with memref pointing to 41 // that memory segment. 42 FailureOr<memref::GlobalOp> globalOp = 43 getGlobalFor(constantOp, state.getOptions().bufferAlignment); 44 if (failed(globalOp)) 45 return failure(); 46 memref::GlobalOp globalMemref = globalOp.getValue(); 47 replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( 48 rewriter, op, globalMemref.type(), globalMemref.getName()); 49 50 return success(); 51 } 52 53 bool isWritable(Operation *op, Value value, 54 const BufferizationState &state) const { 55 // Memory locations returned by memref::GetGlobalOp may not be written to. 56 assert(value.isa<OpResult>()); 57 return false; 58 } 59 }; 60 61 struct IndexCastOpInterface 62 : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface, 63 arith::IndexCastOp> { 64 bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 65 const BufferizationState &state) const { 66 return false; 67 } 68 69 bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 70 const BufferizationState &state) const { 71 return false; 72 } 73 74 OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 75 const BufferizationState &state) const { 76 return op->getResult(0); 77 } 78 79 BufferRelation bufferRelation(Operation *op, OpResult opResult, 80 const BufferizationState &state) const { 81 return BufferRelation::Equivalent; 82 } 83 84 LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 85 const BufferizationState &state) const { 86 auto castOp = cast<arith::IndexCastOp>(op); 87 88 Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); 89 auto sourceType = source.getType().cast<BaseMemRefType>(); 90 91 // Result type should have same layout and address space as the source type. 92 MemRefLayoutAttrInterface layout = {}; 93 if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) 94 layout = rankedMemRefType.getLayout(); 95 Type resultType = 96 getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(), 97 layout, sourceType.getMemorySpace()); 98 99 replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, source, 100 resultType); 101 return success(); 102 } 103 }; 104 105 } // namespace 106 } // namespace arith 107 } // namespace mlir 108 109 void mlir::arith::registerBufferizableOpInterfaceExternalModels( 110 DialectRegistry ®istry) { 111 registry.addOpInterface<ConstantOp, ConstantOpInterface>(); 112 registry.addOpInterface<IndexCastOp, IndexCastOpInterface>(); 113 } 114