1075e3fddSMatthias Springer //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// 2075e3fddSMatthias Springer // 3075e3fddSMatthias Springer // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4075e3fddSMatthias Springer // See https://llvm.org/LICENSE.txt for license information. 5075e3fddSMatthias Springer // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6075e3fddSMatthias Springer // 7075e3fddSMatthias Springer //===----------------------------------------------------------------------===// 8075e3fddSMatthias Springer 9075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" 10075e3fddSMatthias Springer #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 12075e3fddSMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" 13075e3fddSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h" 14075e3fddSMatthias Springer #include "mlir/IR/Dialect.h" 15075e3fddSMatthias Springer #include "mlir/IR/Operation.h" 16075e3fddSMatthias Springer 17dec8af70SRiver Riddle using namespace mlir; 18075e3fddSMatthias Springer using namespace mlir::bufferization; 19075e3fddSMatthias Springer 20075e3fddSMatthias Springer namespace { 21075e3fddSMatthias Springer /// Bufferization of arith.constant. Replace with memref.get_global. 22075e3fddSMatthias Springer struct ConstantOpInterface 23075e3fddSMatthias Springer : public BufferizableOpInterface::ExternalModel<ConstantOpInterface, 24075e3fddSMatthias Springer arith::ConstantOp> { 25075e3fddSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 269597b16aSMatthias Springer BufferizationState &state) const { 27075e3fddSMatthias Springer auto constantOp = cast<arith::ConstantOp>(op); 28075e3fddSMatthias Springer 29075e3fddSMatthias Springer // Only ranked tensors are supported. 30075e3fddSMatthias Springer if (!constantOp.getType().isa<RankedTensorType>()) 31075e3fddSMatthias Springer return failure(); 32075e3fddSMatthias Springer 33075e3fddSMatthias Springer // Only constants inside a module are supported. 34075e3fddSMatthias Springer auto moduleOp = constantOp->getParentOfType<ModuleOp>(); 35075e3fddSMatthias Springer if (!moduleOp) 36075e3fddSMatthias Springer return failure(); 37075e3fddSMatthias Springer 38075e3fddSMatthias Springer // Create global memory segment and replace tensor with memref pointing to 39075e3fddSMatthias Springer // that memory segment. 40ab47418dSMatthias Springer FailureOr<memref::GlobalOp> globalOp = 41ab47418dSMatthias Springer getGlobalFor(constantOp, state.getOptions().bufferAlignment); 42ab47418dSMatthias Springer if (failed(globalOp)) 43ab47418dSMatthias Springer return failure(); 44ab47418dSMatthias Springer memref::GlobalOp globalMemref = globalOp.getValue(); 45075e3fddSMatthias Springer replaceOpWithNewBufferizedOp<memref::GetGlobalOp>( 46075e3fddSMatthias Springer rewriter, op, globalMemref.type(), globalMemref.getName()); 47075e3fddSMatthias Springer 48075e3fddSMatthias Springer return success(); 49075e3fddSMatthias Springer } 50075e3fddSMatthias Springer 51075e3fddSMatthias Springer bool isWritable(Operation *op, Value value, 529597b16aSMatthias Springer const AnalysisState &state) const { 53075e3fddSMatthias Springer // Memory locations returned by memref::GetGlobalOp may not be written to. 54075e3fddSMatthias Springer assert(value.isa<OpResult>()); 55075e3fddSMatthias Springer return false; 56075e3fddSMatthias Springer } 57075e3fddSMatthias Springer }; 58075e3fddSMatthias Springer 59075e3fddSMatthias Springer struct IndexCastOpInterface 60075e3fddSMatthias Springer : public BufferizableOpInterface::ExternalModel<IndexCastOpInterface, 61075e3fddSMatthias Springer arith::IndexCastOp> { 62075e3fddSMatthias Springer bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 639597b16aSMatthias Springer const AnalysisState &state) const { 64075e3fddSMatthias Springer return false; 65075e3fddSMatthias Springer } 66075e3fddSMatthias Springer 67075e3fddSMatthias Springer bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 689597b16aSMatthias Springer const AnalysisState &state) const { 69075e3fddSMatthias Springer return false; 70075e3fddSMatthias Springer } 71075e3fddSMatthias Springer 729597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 739597b16aSMatthias Springer const AnalysisState &state) const { 74585a8a32SMatthias Springer return {op->getResult(0)}; 75075e3fddSMatthias Springer } 76075e3fddSMatthias Springer 77075e3fddSMatthias Springer BufferRelation bufferRelation(Operation *op, OpResult opResult, 789597b16aSMatthias Springer const AnalysisState &state) const { 79075e3fddSMatthias Springer return BufferRelation::Equivalent; 80075e3fddSMatthias Springer } 81075e3fddSMatthias Springer 82075e3fddSMatthias Springer LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 839597b16aSMatthias Springer BufferizationState &state) const { 84075e3fddSMatthias Springer auto castOp = cast<arith::IndexCastOp>(op); 85*12e41d92SMatthias Springer auto resultTensorType = castOp.getType().cast<TensorType>(); 86075e3fddSMatthias Springer 87075e3fddSMatthias Springer Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); 88075e3fddSMatthias Springer auto sourceType = source.getType().cast<BaseMemRefType>(); 89075e3fddSMatthias Springer 90075e3fddSMatthias Springer // Result type should have same layout and address space as the source type. 91*12e41d92SMatthias Springer BaseMemRefType resultType; 92*12e41d92SMatthias Springer if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) { 93*12e41d92SMatthias Springer resultType = MemRefType::get( 94*12e41d92SMatthias Springer rankedMemRefType.getShape(), resultTensorType.getElementType(), 95*12e41d92SMatthias Springer rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); 96*12e41d92SMatthias Springer } else { 97*12e41d92SMatthias Springer auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>(); 98*12e41d92SMatthias Springer resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), 99*12e41d92SMatthias Springer unrankedMemrefType.getMemorySpace()); 100*12e41d92SMatthias Springer } 101075e3fddSMatthias Springer 1023c69bc4dSRiver Riddle replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType, 1033c69bc4dSRiver Riddle source); 104075e3fddSMatthias Springer return success(); 105075e3fddSMatthias Springer } 106075e3fddSMatthias Springer }; 107075e3fddSMatthias Springer 108dec8af70SRiver Riddle /// Bufferization of arith.select. Just replace the operands. 109dec8af70SRiver Riddle struct SelectOpInterface 110dec8af70SRiver Riddle : public BufferizableOpInterface::ExternalModel<SelectOpInterface, 111dec8af70SRiver Riddle arith::SelectOp> { 112dec8af70SRiver Riddle bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, 1139597b16aSMatthias Springer const AnalysisState &state) const { 114dec8af70SRiver Riddle return false; 115dec8af70SRiver Riddle } 116dec8af70SRiver Riddle 117dec8af70SRiver Riddle bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, 1189597b16aSMatthias Springer const AnalysisState &state) const { 119dec8af70SRiver Riddle return false; 120dec8af70SRiver Riddle } 121dec8af70SRiver Riddle 1229597b16aSMatthias Springer SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand, 1239597b16aSMatthias Springer const AnalysisState &state) const { 124585a8a32SMatthias Springer return {op->getOpResult(0) /*result*/}; 125dec8af70SRiver Riddle } 126dec8af70SRiver Riddle 127dec8af70SRiver Riddle SmallVector<OpOperand *> 128dec8af70SRiver Riddle getAliasingOpOperand(Operation *op, OpResult opResult, 1299597b16aSMatthias Springer const AnalysisState &state) const { 130dec8af70SRiver Riddle return {&op->getOpOperand(1) /*true_value*/, 131dec8af70SRiver Riddle &op->getOpOperand(2) /*false_value*/}; 132dec8af70SRiver Riddle } 133dec8af70SRiver Riddle 134dec8af70SRiver Riddle LogicalResult bufferize(Operation *op, RewriterBase &rewriter, 1359597b16aSMatthias Springer BufferizationState &state) const { 136dec8af70SRiver Riddle auto selectOp = cast<arith::SelectOp>(op); 1378b091419SMatthias Springer Location loc = selectOp.getLoc(); 138dec8af70SRiver Riddle 139dec8af70SRiver Riddle // `getBuffer` introduces copies if an OpOperand bufferizes out-of-place. 140dec8af70SRiver Riddle // TODO: It would be more efficient to copy the result of the `select` op 141dec8af70SRiver Riddle // instead of its OpOperands. In the worst case, 2 copies are inserted at 142dec8af70SRiver Riddle // the moment (one for each tensor). When copying the op result, only one 143dec8af70SRiver Riddle // copy would be needed. 144dec8af70SRiver Riddle Value trueBuffer = 145dec8af70SRiver Riddle *state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/); 146dec8af70SRiver Riddle Value falseBuffer = 147dec8af70SRiver Riddle *state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/); 1488b091419SMatthias Springer 1498b091419SMatthias Springer // The "true" and the "false" operands must have the same type. If the 1508b091419SMatthias Springer // buffers have different types, they differ only in their layout map. Cast 1518b091419SMatthias Springer // both of them to the most dynamic MemRef type. 1528b091419SMatthias Springer if (trueBuffer.getType() != falseBuffer.getType()) { 1538b091419SMatthias Springer auto trueType = trueBuffer.getType().cast<MemRefType>(); 1548b091419SMatthias Springer int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; 155*12e41d92SMatthias Springer SmallVector<int64_t> dynamicStrides(trueType.getRank(), 1568b091419SMatthias Springer ShapedType::kDynamicStrideOrOffset); 1578b091419SMatthias Springer AffineMap stridedLayout = makeStridedLinearLayoutMap( 1588b091419SMatthias Springer dynamicStrides, dynamicOffset, op->getContext()); 159*12e41d92SMatthias Springer auto castedType = 160*12e41d92SMatthias Springer MemRefType::get(trueType.getShape(), trueType.getElementType(), 161*12e41d92SMatthias Springer stridedLayout, trueType.getMemorySpaceAsInt()); 1628b091419SMatthias Springer trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer); 1638b091419SMatthias Springer falseBuffer = 1648b091419SMatthias Springer rewriter.create<memref::CastOp>(loc, castedType, falseBuffer); 1658b091419SMatthias Springer } 1668b091419SMatthias Springer 167dec8af70SRiver Riddle replaceOpWithNewBufferizedOp<arith::SelectOp>( 168dec8af70SRiver Riddle rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); 169dec8af70SRiver Riddle return success(); 170dec8af70SRiver Riddle } 171dec8af70SRiver Riddle 172dec8af70SRiver Riddle BufferRelation bufferRelation(Operation *op, OpResult opResult, 1739597b16aSMatthias Springer const AnalysisState &state) const { 174dec8af70SRiver Riddle return BufferRelation::None; 175dec8af70SRiver Riddle } 176dec8af70SRiver Riddle }; 177dec8af70SRiver Riddle 178075e3fddSMatthias Springer } // namespace 179075e3fddSMatthias Springer 180075e3fddSMatthias Springer void mlir::arith::registerBufferizableOpInterfaceExternalModels( 181075e3fddSMatthias Springer DialectRegistry ®istry) { 18277eee579SRiver Riddle registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) { 18377eee579SRiver Riddle ConstantOp::attachInterface<ConstantOpInterface>(*ctx); 18477eee579SRiver Riddle IndexCastOp::attachInterface<IndexCastOpInterface>(*ctx); 18577eee579SRiver Riddle SelectOp::attachInterface<SelectOpInterface>(*ctx); 18677eee579SRiver Riddle }); 187075e3fddSMatthias Springer } 188