1075e3fddSMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2075e3fddSMatthias Springer // 3075e3fddSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4075e3fddSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5075e3fddSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6075e3fddSMatthias Springer // 7075e3fddSMatthias Springer //===----------------------------------------------------------------------===// 8075e3fddSMatthias Springer 9075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" 10075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 13075e3fddSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 14075e3fddSMatthias Springer #include "mlir/IR/Dialect.h" 15075e3fddSMatthias Springer #include "mlir/IR/Operation.h" 16075e3fddSMatthias Springer 17dec8af70SRiver Riddle using namespace mlir; 18075e3fddSMatthias Springer using namespace mlir::bufferization; 19075e3fddSMatthias Springer 20075e3fddSMatthias Springer namespace { 21075e3fddSMatthias Springer /// Bufferization of arith.constant. Replace with memref.get_global. 22075e3fddSMatthias Springer struct ConstantOpInterface 23075e3fddSMatthias Springer : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, 24075e3fddSMatthias Springer arith::ConstantOp> { 25075e3fddSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 26075e3fddSMatthias Springer const BufferizationState &state) const { 27075e3fddSMatthias Springer auto constantOp = cast<arith::ConstantOp>(op); 28075e3fddSMatthias Springer 29075e3fddSMatthias Springer // Only ranked tensors are supported. 30075e3fddSMatthias Springer if (!constantOp.getType().isa<RankedTensorType>()) 31075e3fddSMatthias Springer return failure(); 32075e3fddSMatthias Springer 33075e3fddSMatthias Springer // Only constants inside a module are supported. 34075e3fddSMatthias Springer auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 35075e3fddSMatthias Springer if (!moduleOp) 36075e3fddSMatthias Springer return failure(); 37075e3fddSMatthias Springer 38075e3fddSMatthias Springer // Create global memory segment and replace tensor with memref pointing to 39075e3fddSMatthias Springer // that memory segment. 40ab47418dSMatthias Springer FailureOr<memref::GlobalOp> globalOp = 41ab47418dSMatthias Springer getGlobalFor(constantOp, state.getOptions().bufferAlignment); 42ab47418dSMatthias Springer if (failed(globalOp)) 43ab47418dSMatthias Springer return failure(); 44ab47418dSMatthias Springer memref::GlobalOp globalMemref = globalOp.getValue(); 45075e3fddSMatthias Springer replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( 46075e3fddSMatthias Springer rewriter, op, globalMemref.type(), globalMemref.getName()); 47075e3fddSMatthias Springer 48075e3fddSMatthias Springer return success(); 49075e3fddSMatthias Springer } 50075e3fddSMatthias Springer 51075e3fddSMatthias Springer bool isWritable(Operation *op, Value value, 52075e3fddSMatthias Springer const BufferizationState &state) const { 53075e3fddSMatthias Springer // Memory locations returned by memref::GetGlobalOp may not be written to. 54075e3fddSMatthias Springer assert(value.isa<OpResult>()); 55075e3fddSMatthias Springer return false; 56075e3fddSMatthias Springer } 57075e3fddSMatthias Springer }; 58075e3fddSMatthias Springer 59075e3fddSMatthias Springer struct IndexCastOpInterface 60075e3fddSMatthias Springer : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface, 61075e3fddSMatthias Springer arith::IndexCastOp> { 62075e3fddSMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 63075e3fddSMatthias Springer const BufferizationState &state) const { 64075e3fddSMatthias Springer return false; 65075e3fddSMatthias Springer } 66075e3fddSMatthias Springer 67075e3fddSMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 68075e3fddSMatthias Springer const BufferizationState &state) const { 69075e3fddSMatthias Springer return false; 70075e3fddSMatthias Springer } 71075e3fddSMatthias Springer 72075e3fddSMatthias Springer OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 73075e3fddSMatthias Springer const BufferizationState &state) const { 74075e3fddSMatthias Springer return op->getResult(0); 75075e3fddSMatthias Springer } 76075e3fddSMatthias Springer 77075e3fddSMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 78075e3fddSMatthias Springer const BufferizationState &state) const { 79075e3fddSMatthias Springer return BufferRelation::Equivalent; 80075e3fddSMatthias Springer } 81075e3fddSMatthias Springer 82075e3fddSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 83075e3fddSMatthias Springer const BufferizationState &state) const { 84075e3fddSMatthias Springer auto castOp = cast<arith::IndexCastOp>(op); 85075e3fddSMatthias Springer 86075e3fddSMatthias Springer Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); 87075e3fddSMatthias Springer auto sourceType = source.getType().cast<BaseMemRefType>(); 88075e3fddSMatthias Springer 89075e3fddSMatthias Springer // Result type should have same layout and address space as the source type. 90075e3fddSMatthias Springer MemRefLayoutAttrInterface layout = {}; 91075e3fddSMatthias Springer if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) 92075e3fddSMatthias Springer layout = rankedMemRefType.getLayout(); 93075e3fddSMatthias Springer Type resultType = 94075e3fddSMatthias Springer getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(), 95075e3fddSMatthias Springer layout, sourceType.getMemorySpace()); 96075e3fddSMatthias Springer 97*3c69bc4dSRiver Riddle replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType, 98*3c69bc4dSRiver Riddle source); 99075e3fddSMatthias Springer return success(); 100075e3fddSMatthias Springer } 101075e3fddSMatthias Springer }; 102075e3fddSMatthias Springer 103dec8af70SRiver Riddle /// Bufferization of arith.select. Just replace the operands. 104dec8af70SRiver Riddle struct SelectOpInterface 105dec8af70SRiver Riddle : public BufferizableOpInterface::ExternalModel<SelectOpInterface, 106dec8af70SRiver Riddle arith::SelectOp> { 107dec8af70SRiver Riddle bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 108dec8af70SRiver Riddle const BufferizationState &state) const { 109dec8af70SRiver Riddle return false; 110dec8af70SRiver Riddle } 111dec8af70SRiver Riddle 112dec8af70SRiver Riddle bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 113dec8af70SRiver Riddle const BufferizationState &state) const { 114dec8af70SRiver Riddle return false; 115dec8af70SRiver Riddle } 116dec8af70SRiver Riddle 117dec8af70SRiver Riddle OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, 118dec8af70SRiver Riddle const BufferizationState &state) const { 119dec8af70SRiver Riddle return op->getOpResult(0) /*result*/; 120dec8af70SRiver Riddle } 121dec8af70SRiver Riddle 122dec8af70SRiver Riddle SmallVector<OpOperand *> 123dec8af70SRiver Riddle getAliasingOpOperand(Operation *op, OpResult opResult, 124dec8af70SRiver Riddle const BufferizationState &state) const { 125dec8af70SRiver Riddle return {&op->getOpOperand(1) /*true_value*/, 126dec8af70SRiver Riddle &op->getOpOperand(2) /*false_value*/}; 127dec8af70SRiver Riddle } 128dec8af70SRiver Riddle 129dec8af70SRiver Riddle LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 130dec8af70SRiver Riddle const BufferizationState &state) const { 131dec8af70SRiver Riddle auto selectOp = cast<arith::SelectOp>(op); 132dec8af70SRiver Riddle 133dec8af70SRiver Riddle // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. 134dec8af70SRiver Riddle // TODO: It would be more efficient to copy the result of the `select` op 135dec8af70SRiver Riddle // instead of its OpOperands. In the worst case, 2 copies are inserted at 136dec8af70SRiver Riddle // the moment (one for each tensor). When copying the op result, only one 137dec8af70SRiver Riddle // copy would be needed. 138dec8af70SRiver Riddle Value trueBuffer = 139dec8af70SRiver Riddle *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); 140dec8af70SRiver Riddle Value falseBuffer = 141dec8af70SRiver Riddle *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); 142dec8af70SRiver Riddle replaceOpWithNewBufferizedOp<arith::SelectOp>( 143dec8af70SRiver Riddle rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); 144dec8af70SRiver Riddle return success(); 145dec8af70SRiver Riddle } 146dec8af70SRiver Riddle 147dec8af70SRiver Riddle BufferRelation bufferRelation(Operation *op, OpResult opResult, 148dec8af70SRiver Riddle const BufferizationState &state) const { 149dec8af70SRiver Riddle return BufferRelation::None; 150dec8af70SRiver Riddle } 151dec8af70SRiver Riddle }; 152dec8af70SRiver Riddle 153075e3fddSMatthias Springer } // namespace 154075e3fddSMatthias Springer 155075e3fddSMatthias Springer void mlir::arith::registerBufferizableOpInterfaceExternalModels( 156075e3fddSMatthias Springer DialectRegistry ®istry) { 157075e3fddSMatthias Springer registry.addOpInterface<ConstantOp, ConstantOpInterface>(); 158075e3fddSMatthias Springer registry.addOpInterface<IndexCastOp, IndexCastOpInterface>(); 159dec8af70SRiver Riddle registry.addOpInterface<SelectOp, SelectOpInterface>(); 160075e3fddSMatthias Springer } 161