157470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 257470abcSAlexander Belyaev // 357470abcSAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 457470abcSAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information. 557470abcSAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 657470abcSAlexander Belyaev // 757470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 857470abcSAlexander Belyaev 9ffdbecccSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h" 10*3474d10eSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" 1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 12eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h" 13eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h" 14e07a7fd5SMatthias Springer #include "mlir/IR/FunctionInterfaces.h" 1591072b74SButygin #include "mlir/Transforms/InliningUtils.h" 1657470abcSAlexander Belyaev 1757470abcSAlexander Belyaev using namespace mlir; 1857470abcSAlexander Belyaev using namespace mlir::bufferization; 1957470abcSAlexander Belyaev 2057470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.cpp.inc" 2157470abcSAlexander Belyaev 22e07a7fd5SMatthias Springer /// Attribute name used to mark function arguments who's buffers can be written 23e07a7fd5SMatthias Springer /// to during One-Shot Module Bufferize. 24e07a7fd5SMatthias Springer constexpr const ::llvm::StringLiteral BufferizationDialect::kWritableAttrName; 25e07a7fd5SMatthias Springer 26e07a7fd5SMatthias Springer /// Attribute name used to mark the bufferization layout for region arguments 27e07a7fd5SMatthias Springer /// during One-Shot Module Bufferize. 28e07a7fd5SMatthias Springer constexpr const ::llvm::StringLiteral 29e07a7fd5SMatthias Springer BufferizationDialect::kBufferLayoutAttrName; 30e07a7fd5SMatthias Springer 31*3474d10eSMatthias Springer /// Attribute name used to mark escaping behavior of buffer allocations. 32*3474d10eSMatthias Springer constexpr const ::llvm::StringLiteral BufferizationDialect::kEscapeAttrName; 33*3474d10eSMatthias Springer 3457470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 3591072b74SButygin // Bufferization Dialect Interfaces 3691072b74SButygin //===----------------------------------------------------------------------===// 3791072b74SButygin 3891072b74SButygin namespace { 3991072b74SButygin struct BufferizationInlinerInterface : public DialectInlinerInterface { 4091072b74SButygin using DialectInlinerInterface::DialectInlinerInterface; 4191072b74SButygin 4291072b74SButygin /// Operations in Bufferization dialect are always legal to inline. isLegalToInline__anon4220dc3d0111::BufferizationInlinerInterface4391072b74SButygin bool isLegalToInline(Operation *, Region *, bool, 4491072b74SButygin BlockAndValueMapping &) const final { 4591072b74SButygin return true; 4691072b74SButygin } 4791072b74SButygin }; 48be0a7e9fSMehdi Amini } // namespace 4991072b74SButygin 5091072b74SButygin //===----------------------------------------------------------------------===// 5157470abcSAlexander Belyaev // Bufferization Dialect 5257470abcSAlexander Belyaev //===----------------------------------------------------------------------===// 5357470abcSAlexander Belyaev initialize()5457470abcSAlexander Belyaevvoid mlir::bufferization::BufferizationDialect::initialize() { 5557470abcSAlexander Belyaev addOperations< 5657470abcSAlexander Belyaev #define GET_OP_LIST 5757470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc" 5857470abcSAlexander Belyaev >(); 5991072b74SButygin addInterfaces<BufferizationInlinerInterface>(); 6057470abcSAlexander Belyaev } 61e07a7fd5SMatthias Springer 62e07a7fd5SMatthias Springer LogicalResult verifyOperationAttribute(Operation * op,NamedAttribute attr)63e07a7fd5SMatthias SpringerBufferizationDialect::verifyOperationAttribute(Operation *op, 64e07a7fd5SMatthias Springer NamedAttribute attr) { 65e07a7fd5SMatthias Springer using bufferization::BufferizableOpInterface; 66e07a7fd5SMatthias Springer 67e07a7fd5SMatthias Springer if (attr.getName() == kWritableAttrName) { 68e07a7fd5SMatthias Springer if (!attr.getValue().isa<BoolAttr>()) { 69e07a7fd5SMatthias Springer return op->emitError() << "'" << kWritableAttrName 70e07a7fd5SMatthias Springer << "' is expected to be a boolean attribute"; 71e07a7fd5SMatthias Springer } 72e07a7fd5SMatthias Springer if (!isa<FunctionOpInterface>(op)) 73e07a7fd5SMatthias Springer return op->emitError() << "expected " << attr.getName() 74e07a7fd5SMatthias Springer << " to be used on function-like operations"; 75e07a7fd5SMatthias Springer return success(); 76e07a7fd5SMatthias Springer } 77e07a7fd5SMatthias Springer if (attr.getName() == kBufferLayoutAttrName) { 78e07a7fd5SMatthias Springer if (!attr.getValue().isa<AffineMapAttr>()) { 79e07a7fd5SMatthias Springer return op->emitError() << "'" << kBufferLayoutAttrName 80e07a7fd5SMatthias Springer << "' is expected to be a affine map attribute"; 81e07a7fd5SMatthias Springer } 82e07a7fd5SMatthias Springer if (!isa<FunctionOpInterface>(op)) 83e07a7fd5SMatthias Springer return op->emitError() << "expected " << attr.getName() 84e07a7fd5SMatthias Springer << " to be used on function-like operations"; 85e07a7fd5SMatthias Springer return success(); 86e07a7fd5SMatthias Springer } 87*3474d10eSMatthias Springer if (attr.getName() == kEscapeAttrName) { 88*3474d10eSMatthias Springer auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); 89*3474d10eSMatthias Springer if (!arrayAttr) 90*3474d10eSMatthias Springer return op->emitError() << "'" << kEscapeAttrName 91*3474d10eSMatthias Springer << "' is expected to be a bool array attribute"; 92*3474d10eSMatthias Springer if (arrayAttr.size() != op->getNumResults()) 93*3474d10eSMatthias Springer return op->emitError() 94*3474d10eSMatthias Springer << "'" << kEscapeAttrName 95*3474d10eSMatthias Springer << "' has wrong number of elements, expected " 96*3474d10eSMatthias Springer << op->getNumResults() << ", got " << arrayAttr.size(); 97*3474d10eSMatthias Springer auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op); 98*3474d10eSMatthias Springer if (!bufferizableOp) 99*3474d10eSMatthias Springer return op->emitError() 100*3474d10eSMatthias Springer << "'" << kEscapeAttrName << "' only valid on bufferizable ops"; 101*3474d10eSMatthias Springer for (const auto &it : llvm::enumerate(arrayAttr)) { 102*3474d10eSMatthias Springer auto attr = it.value(); 103*3474d10eSMatthias Springer auto boolAttr = attr.dyn_cast<BoolAttr>(); 104*3474d10eSMatthias Springer if (!boolAttr) 105*3474d10eSMatthias Springer return op->emitError() << "'" << kEscapeAttrName 106*3474d10eSMatthias Springer << "' is expected to be a bool array attribute"; 107*3474d10eSMatthias Springer if (!boolAttr.getValue()) 108*3474d10eSMatthias Springer continue; 109*3474d10eSMatthias Springer if (!op->getResult(it.index()).getType().isa<TensorType>()) 110*3474d10eSMatthias Springer return op->emitError() 111*3474d10eSMatthias Springer << "'" << kEscapeAttrName << "' only valid for tensor results"; 112*3474d10eSMatthias Springer if (!bufferizableOp.bufferizesToAllocation(op->getOpResult(it.index()))) 113*3474d10eSMatthias Springer return op->emitError() << "'" << kEscapeAttrName 114*3474d10eSMatthias Springer << "' only valid for allocation results"; 115*3474d10eSMatthias Springer } 116*3474d10eSMatthias Springer return success(); 117*3474d10eSMatthias Springer } 118e07a7fd5SMatthias Springer 119e07a7fd5SMatthias Springer return op->emitError() << "attribute '" << attr.getName() 120e07a7fd5SMatthias Springer << "' not supported by the bufferization dialect"; 121e07a7fd5SMatthias Springer } 122