157470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
257470abcSAlexander Belyaev //
357470abcSAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
457470abcSAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
557470abcSAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
657470abcSAlexander Belyaev //
757470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
857470abcSAlexander Belyaev 
9ffdbecccSMatthias Springer #include "mlir/Dialect/Affine/IR/AffineOps.h"
10*3474d10eSMatthias Springer #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1157470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12eda6f907SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
13eda6f907SRiver Riddle #include "mlir/Dialect/Tensor/IR/Tensor.h"
14e07a7fd5SMatthias Springer #include "mlir/IR/FunctionInterfaces.h"
1591072b74SButygin #include "mlir/Transforms/InliningUtils.h"
1657470abcSAlexander Belyaev 
1757470abcSAlexander Belyaev using namespace mlir;
1857470abcSAlexander Belyaev using namespace mlir::bufferization;
1957470abcSAlexander Belyaev 
2057470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.cpp.inc"
2157470abcSAlexander Belyaev 
22e07a7fd5SMatthias Springer /// Attribute name used to mark function arguments who's buffers can be written
23e07a7fd5SMatthias Springer /// to during One-Shot Module Bufferize.
24e07a7fd5SMatthias Springer constexpr const ::llvm::StringLiteral BufferizationDialect::kWritableAttrName;
25e07a7fd5SMatthias Springer 
26e07a7fd5SMatthias Springer /// Attribute name used to mark the bufferization layout for region arguments
27e07a7fd5SMatthias Springer /// during One-Shot Module Bufferize.
28e07a7fd5SMatthias Springer constexpr const ::llvm::StringLiteral
29e07a7fd5SMatthias Springer     BufferizationDialect::kBufferLayoutAttrName;
30e07a7fd5SMatthias Springer 
31*3474d10eSMatthias Springer /// Attribute name used to mark escaping behavior of buffer allocations.
32*3474d10eSMatthias Springer constexpr const ::llvm::StringLiteral BufferizationDialect::kEscapeAttrName;
33*3474d10eSMatthias Springer 
3457470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
3591072b74SButygin // Bufferization Dialect Interfaces
3691072b74SButygin //===----------------------------------------------------------------------===//
3791072b74SButygin 
3891072b74SButygin namespace {
3991072b74SButygin struct BufferizationInlinerInterface : public DialectInlinerInterface {
4091072b74SButygin   using DialectInlinerInterface::DialectInlinerInterface;
4191072b74SButygin 
4291072b74SButygin   /// Operations in Bufferization dialect are always legal to inline.
isLegalToInline__anon4220dc3d0111::BufferizationInlinerInterface4391072b74SButygin   bool isLegalToInline(Operation *, Region *, bool,
4491072b74SButygin                        BlockAndValueMapping &) const final {
4591072b74SButygin     return true;
4691072b74SButygin   }
4791072b74SButygin };
48be0a7e9fSMehdi Amini } // namespace
4991072b74SButygin 
5091072b74SButygin //===----------------------------------------------------------------------===//
5157470abcSAlexander Belyaev // Bufferization Dialect
5257470abcSAlexander Belyaev //===----------------------------------------------------------------------===//
5357470abcSAlexander Belyaev 
initialize()5457470abcSAlexander Belyaev void mlir::bufferization::BufferizationDialect::initialize() {
5557470abcSAlexander Belyaev   addOperations<
5657470abcSAlexander Belyaev #define GET_OP_LIST
5757470abcSAlexander Belyaev #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
5857470abcSAlexander Belyaev       >();
5991072b74SButygin   addInterfaces<BufferizationInlinerInterface>();
6057470abcSAlexander Belyaev }
61e07a7fd5SMatthias Springer 
62e07a7fd5SMatthias Springer LogicalResult
verifyOperationAttribute(Operation * op,NamedAttribute attr)63e07a7fd5SMatthias Springer BufferizationDialect::verifyOperationAttribute(Operation *op,
64e07a7fd5SMatthias Springer                                                NamedAttribute attr) {
65e07a7fd5SMatthias Springer   using bufferization::BufferizableOpInterface;
66e07a7fd5SMatthias Springer 
67e07a7fd5SMatthias Springer   if (attr.getName() == kWritableAttrName) {
68e07a7fd5SMatthias Springer     if (!attr.getValue().isa<BoolAttr>()) {
69e07a7fd5SMatthias Springer       return op->emitError() << "'" << kWritableAttrName
70e07a7fd5SMatthias Springer                              << "' is expected to be a boolean attribute";
71e07a7fd5SMatthias Springer     }
72e07a7fd5SMatthias Springer     if (!isa<FunctionOpInterface>(op))
73e07a7fd5SMatthias Springer       return op->emitError() << "expected " << attr.getName()
74e07a7fd5SMatthias Springer                              << " to be used on function-like operations";
75e07a7fd5SMatthias Springer     return success();
76e07a7fd5SMatthias Springer   }
77e07a7fd5SMatthias Springer   if (attr.getName() == kBufferLayoutAttrName) {
78e07a7fd5SMatthias Springer     if (!attr.getValue().isa<AffineMapAttr>()) {
79e07a7fd5SMatthias Springer       return op->emitError() << "'" << kBufferLayoutAttrName
80e07a7fd5SMatthias Springer                              << "' is expected to be a affine map attribute";
81e07a7fd5SMatthias Springer     }
82e07a7fd5SMatthias Springer     if (!isa<FunctionOpInterface>(op))
83e07a7fd5SMatthias Springer       return op->emitError() << "expected " << attr.getName()
84e07a7fd5SMatthias Springer                              << " to be used on function-like operations";
85e07a7fd5SMatthias Springer     return success();
86e07a7fd5SMatthias Springer   }
87*3474d10eSMatthias Springer   if (attr.getName() == kEscapeAttrName) {
88*3474d10eSMatthias Springer     auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
89*3474d10eSMatthias Springer     if (!arrayAttr)
90*3474d10eSMatthias Springer       return op->emitError() << "'" << kEscapeAttrName
91*3474d10eSMatthias Springer                              << "' is expected to be a bool array attribute";
92*3474d10eSMatthias Springer     if (arrayAttr.size() != op->getNumResults())
93*3474d10eSMatthias Springer       return op->emitError()
94*3474d10eSMatthias Springer              << "'" << kEscapeAttrName
95*3474d10eSMatthias Springer              << "' has wrong number of elements, expected "
96*3474d10eSMatthias Springer              << op->getNumResults() << ", got " << arrayAttr.size();
97*3474d10eSMatthias Springer     auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
98*3474d10eSMatthias Springer     if (!bufferizableOp)
99*3474d10eSMatthias Springer       return op->emitError()
100*3474d10eSMatthias Springer              << "'" << kEscapeAttrName << "' only valid on bufferizable ops";
101*3474d10eSMatthias Springer     for (const auto &it : llvm::enumerate(arrayAttr)) {
102*3474d10eSMatthias Springer       auto attr = it.value();
103*3474d10eSMatthias Springer       auto boolAttr = attr.dyn_cast<BoolAttr>();
104*3474d10eSMatthias Springer       if (!boolAttr)
105*3474d10eSMatthias Springer         return op->emitError() << "'" << kEscapeAttrName
106*3474d10eSMatthias Springer                                << "' is expected to be a bool array attribute";
107*3474d10eSMatthias Springer       if (!boolAttr.getValue())
108*3474d10eSMatthias Springer         continue;
109*3474d10eSMatthias Springer       if (!op->getResult(it.index()).getType().isa<TensorType>())
110*3474d10eSMatthias Springer         return op->emitError()
111*3474d10eSMatthias Springer                << "'" << kEscapeAttrName << "' only valid for tensor results";
112*3474d10eSMatthias Springer       if (!bufferizableOp.bufferizesToAllocation(op->getOpResult(it.index())))
113*3474d10eSMatthias Springer         return op->emitError() << "'" << kEscapeAttrName
114*3474d10eSMatthias Springer                                << "' only valid for allocation results";
115*3474d10eSMatthias Springer     }
116*3474d10eSMatthias Springer     return success();
117*3474d10eSMatthias Springer   }
118e07a7fd5SMatthias Springer 
119e07a7fd5SMatthias Springer   return op->emitError() << "attribute '" << attr.getName()
120e07a7fd5SMatthias Springer                          << "' not supported by the bufferization dialect";
121e07a7fd5SMatthias Springer }
122