1075e3fddSMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2075e3fddSMatthias Springer // 3075e3fddSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4075e3fddSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5075e3fddSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6075e3fddSMatthias Springer // 7075e3fddSMatthias Springer //===----------------------------------------------------------------------===// 8075e3fddSMatthias Springer 9075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" 10075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 13075e3fddSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 14075e3fddSMatthias Springer #include "mlir/IR/Dialect.h" 15075e3fddSMatthias Springer #include "mlir/IR/Operation.h" 16075e3fddSMatthias Springer 17dec8af70SRiver Riddle using namespace mlir; 18075e3fddSMatthias Springer using namespace mlir::bufferization; 19075e3fddSMatthias Springer 20075e3fddSMatthias Springer namespace { 21075e3fddSMatthias Springer /// Bufferization of arith.constant. Replace with memref.get_global. 22075e3fddSMatthias Springer struct ConstantOpInterface 23075e3fddSMatthias Springer : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, 24075e3fddSMatthias Springer arith::ConstantOp> { 25075e3fddSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 26075e3fddSMatthias Springer const BufferizationState &state) const { 27075e3fddSMatthias Springer auto constantOp = cast<arith::ConstantOp>(op); 28075e3fddSMatthias Springer 29075e3fddSMatthias Springer // Only ranked tensors are supported. 30075e3fddSMatthias Springer if (!constantOp.getType().isa<RankedTensorType>()) 31075e3fddSMatthias Springer return failure(); 32075e3fddSMatthias Springer 33075e3fddSMatthias Springer // Only constants inside a module are supported. 34075e3fddSMatthias Springer auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 35075e3fddSMatthias Springer if (!moduleOp) 36075e3fddSMatthias Springer return failure(); 37075e3fddSMatthias Springer 38075e3fddSMatthias Springer // Create global memory segment and replace tensor with memref pointing to 39075e3fddSMatthias Springer // that memory segment. 40ab47418dSMatthias Springer FailureOr<memref::GlobalOp> globalOp = 41ab47418dSMatthias Springer getGlobalFor(constantOp, state.getOptions().bufferAlignment); 42ab47418dSMatthias Springer if (failed(globalOp)) 43ab47418dSMatthias Springer return failure(); 44ab47418dSMatthias Springer memref::GlobalOp globalMemref = globalOp.getValue(); 45075e3fddSMatthias Springer replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( 46075e3fddSMatthias Springer rewriter, op, globalMemref.type(), globalMemref.getName()); 47075e3fddSMatthias Springer 48075e3fddSMatthias Springer return success(); 49075e3fddSMatthias Springer } 50075e3fddSMatthias Springer 51075e3fddSMatthias Springer bool isWritable(Operation *op, Value value, 52075e3fddSMatthias Springer const BufferizationState &state) const { 53075e3fddSMatthias Springer // Memory locations returned by memref::GetGlobalOp may not be written to. 54075e3fddSMatthias Springer assert(value.isa<OpResult>()); 55075e3fddSMatthias Springer return false; 56075e3fddSMatthias Springer } 57075e3fddSMatthias Springer }; 58075e3fddSMatthias Springer 59075e3fddSMatthias Springer struct IndexCastOpInterface 60075e3fddSMatthias Springer : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface, 61075e3fddSMatthias Springer arith::IndexCastOp> { 62075e3fddSMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 63075e3fddSMatthias Springer const BufferizationState &state) const { 64075e3fddSMatthias Springer return false; 65075e3fddSMatthias Springer } 66075e3fddSMatthias Springer 67075e3fddSMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 68075e3fddSMatthias Springer const BufferizationState &state) const { 69075e3fddSMatthias Springer return false; 70075e3fddSMatthias Springer } 71075e3fddSMatthias Springer 72*585a8a32SMatthias Springer SmallVector<OpResult> 73*585a8a32SMatthias Springer getAliasingOpResult(Operation *op, OpOperand &opOperand, 74075e3fddSMatthias Springer const BufferizationState &state) const { 75*585a8a32SMatthias Springer return {op->getResult(0)}; 76075e3fddSMatthias Springer } 77075e3fddSMatthias Springer 78075e3fddSMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 79075e3fddSMatthias Springer const BufferizationState &state) const { 80075e3fddSMatthias Springer return BufferRelation::Equivalent; 81075e3fddSMatthias Springer } 82075e3fddSMatthias Springer 83075e3fddSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 84075e3fddSMatthias Springer const BufferizationState &state) const { 85075e3fddSMatthias Springer auto castOp = cast<arith::IndexCastOp>(op); 86075e3fddSMatthias Springer 87075e3fddSMatthias Springer Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); 88075e3fddSMatthias Springer auto sourceType = source.getType().cast<BaseMemRefType>(); 89075e3fddSMatthias Springer 90075e3fddSMatthias Springer // Result type should have same layout and address space as the source type. 91075e3fddSMatthias Springer MemRefLayoutAttrInterface layout = {}; 92075e3fddSMatthias Springer if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) 93075e3fddSMatthias Springer layout = rankedMemRefType.getLayout(); 94075e3fddSMatthias Springer Type resultType = 95075e3fddSMatthias Springer getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(), 96075e3fddSMatthias Springer layout, sourceType.getMemorySpace()); 97075e3fddSMatthias Springer 983c69bc4dSRiver Riddle replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType, 993c69bc4dSRiver Riddle source); 100075e3fddSMatthias Springer return success(); 101075e3fddSMatthias Springer } 102075e3fddSMatthias Springer }; 103075e3fddSMatthias Springer 104dec8af70SRiver Riddle /// Bufferization of arith.select. Just replace the operands. 105dec8af70SRiver Riddle struct SelectOpInterface 106dec8af70SRiver Riddle : public BufferizableOpInterface::ExternalModel<SelectOpInterface, 107dec8af70SRiver Riddle arith::SelectOp> { 108dec8af70SRiver Riddle bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 109dec8af70SRiver Riddle const BufferizationState &state) const { 110dec8af70SRiver Riddle return false; 111dec8af70SRiver Riddle } 112dec8af70SRiver Riddle 113dec8af70SRiver Riddle bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 114dec8af70SRiver Riddle const BufferizationState &state) const { 115dec8af70SRiver Riddle return false; 116dec8af70SRiver Riddle } 117dec8af70SRiver Riddle 118*585a8a32SMatthias Springer SmallVector<OpResult> 119*585a8a32SMatthias Springer getAliasingOpResult(Operation *op, OpOperand &opOperand, 120dec8af70SRiver Riddle const BufferizationState &state) const { 121*585a8a32SMatthias Springer return {op->getOpResult(0) /*result*/}; 122dec8af70SRiver Riddle } 123dec8af70SRiver Riddle 124dec8af70SRiver Riddle SmallVector<OpOperand *> 125dec8af70SRiver Riddle getAliasingOpOperand(Operation *op, OpResult opResult, 126dec8af70SRiver Riddle const BufferizationState &state) const { 127dec8af70SRiver Riddle return {&op->getOpOperand(1) /*true_value*/, 128dec8af70SRiver Riddle &op->getOpOperand(2) /*false_value*/}; 129dec8af70SRiver Riddle } 130dec8af70SRiver Riddle 131dec8af70SRiver Riddle LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 132dec8af70SRiver Riddle const BufferizationState &state) const { 133dec8af70SRiver Riddle auto selectOp = cast<arith::SelectOp>(op); 134dec8af70SRiver Riddle 135dec8af70SRiver Riddle // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. 136dec8af70SRiver Riddle // TODO: It would be more efficient to copy the result of the `select` op 137dec8af70SRiver Riddle // instead of its OpOperands. In the worst case, 2 copies are inserted at 138dec8af70SRiver Riddle // the moment (one for each tensor). When copying the op result, only one 139dec8af70SRiver Riddle // copy would be needed. 140dec8af70SRiver Riddle Value trueBuffer = 141dec8af70SRiver Riddle *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); 142dec8af70SRiver Riddle Value falseBuffer = 143dec8af70SRiver Riddle *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); 144dec8af70SRiver Riddle replaceOpWithNewBufferizedOp<arith::SelectOp>( 145dec8af70SRiver Riddle rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); 146dec8af70SRiver Riddle return success(); 147dec8af70SRiver Riddle } 148dec8af70SRiver Riddle 149dec8af70SRiver Riddle BufferRelation bufferRelation(Operation *op, OpResult opResult, 150dec8af70SRiver Riddle const BufferizationState &state) const { 151dec8af70SRiver Riddle return BufferRelation::None; 152dec8af70SRiver Riddle } 153dec8af70SRiver Riddle }; 154dec8af70SRiver Riddle 155075e3fddSMatthias Springer } // namespace 156075e3fddSMatthias Springer 157075e3fddSMatthias Springer void mlir::arith::registerBufferizableOpInterfaceExternalModels( 158075e3fddSMatthias Springer DialectRegistry ®istry) { 159075e3fddSMatthias Springer registry.addOpInterface<ConstantOp, ConstantOpInterface>(); 160075e3fddSMatthias Springer registry.addOpInterface<IndexCastOp, IndexCastOpInterface>(); 161dec8af70SRiver Riddle registry.addOpInterface<SelectOp, SelectOpInterface>(); 162075e3fddSMatthias Springer } 163