//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" using namespace mlir; using namespace mlir::bufferization; namespace { /// Bufferization of arith.constant. Replace with memref.get_global. struct ConstantOpInterface : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto constantOp = cast(op); // TODO: Implement memory space for this op. E.g., by adding a memory_space // attribute to ConstantOp. if (options.defaultMemorySpace != static_cast(0)) return op->emitError("memory space not implemented yet"); // Only ranked tensors are supported. if (!constantOp.getType().isa()) return failure(); // Only constants inside a module are supported. auto moduleOp = constantOp->getParentOfType(); if (!moduleOp) return failure(); // Create global memory segment and replace tensor with memref pointing to // that memory segment. FailureOr globalOp = getGlobalFor(constantOp, options.bufferAlignment); if (failed(globalOp)) return failure(); memref::GlobalOp globalMemref = *globalOp; replaceOpWithNewBufferizedOp( rewriter, op, globalMemref.getType(), globalMemref.getName()); return success(); } bool isWritable(Operation *op, Value value, const AnalysisState &state) const { // Memory locations returned by memref::GetGlobalOp may not be written to. assert(value.isa()); return false; } }; struct IndexCastOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {op->getResult(0)}; } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::Equivalent; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto castOp = cast(op); auto resultTensorType = castOp.getType().cast(); FailureOr source = getBuffer(rewriter, castOp.getIn(), options); if (failed(source)) return failure(); auto sourceType = source->getType().cast(); // Result type should have same layout and address space as the source type. BaseMemRefType resultType; if (auto rankedMemRefType = sourceType.dyn_cast()) { resultType = MemRefType::get( rankedMemRefType.getShape(), resultTensorType.getElementType(), rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace()); } else { auto unrankedMemrefType = sourceType.cast(); resultType = UnrankedMemRefType::get(resultTensorType.getElementType(), unrankedMemrefType.getMemorySpace()); } replaceOpWithNewBufferizedOp(rewriter, op, resultType, *source); return success(); } }; /// Bufferization of arith.select. Just replace the operands. struct SelectOpInterface : public BufferizableOpInterface::ExternalModel { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; } SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return {op->getOpResult(0) /*result*/}; } SmallVector getAliasingOpOperand(Operation *op, OpResult opResult, const AnalysisState &state) const { return {&op->getOpOperand(1) /*true_value*/, &op->getOpOperand(2) /*false_value*/}; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto selectOp = cast(op); Location loc = selectOp.getLoc(); // TODO: It would be more efficient to copy the result of the `select` op // instead of its OpOperands. In the worst case, 2 copies are inserted at // the moment (one for each tensor). When copying the op result, only one // copy would be needed. FailureOr maybeTrueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options); FailureOr maybeFalseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options); if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer)) return failure(); Value trueBuffer = *maybeTrueBuffer; Value falseBuffer = *maybeFalseBuffer; BaseMemRefType trueType = trueBuffer.getType().cast(); BaseMemRefType falseType = falseBuffer.getType().cast(); if (trueType.getMemorySpaceAsInt() != falseType.getMemorySpaceAsInt()) return op->emitError("inconsistent memory space on true/false operands"); // The "true" and the "false" operands must have the same type. If the // buffers have different types, they differ only in their layout map. Cast // both of them to the most dynamic MemRef type. if (trueBuffer.getType() != falseBuffer.getType()) { auto trueType = trueBuffer.getType().cast(); int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset; SmallVector dynamicStrides(trueType.getRank(), ShapedType::kDynamicStrideOrOffset); AffineMap stridedLayout = makeStridedLinearLayoutMap( dynamicStrides, dynamicOffset, op->getContext()); auto castedType = MemRefType::get(trueType.getShape(), trueType.getElementType(), stridedLayout, trueType.getMemorySpaceAsInt()); trueBuffer = rewriter.create(loc, castedType, trueBuffer); falseBuffer = rewriter.create(loc, castedType, falseBuffer); } replaceOpWithNewBufferizedOp( rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer); return success(); } BufferRelation bufferRelation(Operation *op, OpResult opResult, const AnalysisState &state) const { return BufferRelation::None; } }; } // namespace void mlir::arith::registerBufferizableOpInterfaceExternalModels( DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, ArithmeticDialect *dialect) { ConstantOp::attachInterface(*ctx); IndexCastOp::attachInterface(*ctx); SelectOp::attachInterface(*ctx); }); }