1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
11 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/IR/FunctionInterfaces.h"
15 #include "mlir/Transforms/InliningUtils.h"
16 
17 using namespace mlir;
18 using namespace mlir::bufferization;
19 
20 #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.cpp.inc"
21 
// Out-of-class definitions of the constexpr static StringLiteral members
// declared in BufferizationDialect. NOTE(review): presumably kept so that
// ODR-uses of these members link when building as C++14, where static
// constexpr data members are not implicitly inline; under C++17 they are
// redundant (and deprecated) but harmless.

/// Attribute name used to mark function arguments whose buffers can be written
/// to during One-Shot Module Bufferize.
constexpr const ::llvm::StringLiteral BufferizationDialect::kWritableAttrName;

/// Attribute name used to mark the bufferization layout for region arguments
/// during One-Shot Module Bufferize.
constexpr const ::llvm::StringLiteral
    BufferizationDialect::kBufferLayoutAttrName;

/// Attribute name used to mark escaping behavior of buffer allocations
/// (verified as a bool array with one entry per op result).
constexpr const ::llvm::StringLiteral BufferizationDialect::kEscapeAttrName;
33 
34 //===----------------------------------------------------------------------===//
35 // Bufferization Dialect Interfaces
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct BufferizationInlinerInterface : public DialectInlinerInterface {
40   using DialectInlinerInterface::DialectInlinerInterface;
41 
42   /// Operations in Bufferization dialect are always legal to inline.
isLegalToInline__anon4220dc3d0111::BufferizationInlinerInterface43   bool isLegalToInline(Operation *, Region *, bool,
44                        BlockAndValueMapping &) const final {
45     return true;
46   }
47 };
48 } // namespace
49 
50 //===----------------------------------------------------------------------===//
51 // Bufferization Dialect
52 //===----------------------------------------------------------------------===//
53 
initialize()54 void mlir::bufferization::BufferizationDialect::initialize() {
55   addOperations<
56 #define GET_OP_LIST
57 #include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
58       >();
59   addInterfaces<BufferizationInlinerInterface>();
60 }
61 
62 LogicalResult
verifyOperationAttribute(Operation * op,NamedAttribute attr)63 BufferizationDialect::verifyOperationAttribute(Operation *op,
64                                                NamedAttribute attr) {
65   using bufferization::BufferizableOpInterface;
66 
67   if (attr.getName() == kWritableAttrName) {
68     if (!attr.getValue().isa<BoolAttr>()) {
69       return op->emitError() << "'" << kWritableAttrName
70                              << "' is expected to be a boolean attribute";
71     }
72     if (!isa<FunctionOpInterface>(op))
73       return op->emitError() << "expected " << attr.getName()
74                              << " to be used on function-like operations";
75     return success();
76   }
77   if (attr.getName() == kBufferLayoutAttrName) {
78     if (!attr.getValue().isa<AffineMapAttr>()) {
79       return op->emitError() << "'" << kBufferLayoutAttrName
80                              << "' is expected to be a affine map attribute";
81     }
82     if (!isa<FunctionOpInterface>(op))
83       return op->emitError() << "expected " << attr.getName()
84                              << " to be used on function-like operations";
85     return success();
86   }
87   if (attr.getName() == kEscapeAttrName) {
88     auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
89     if (!arrayAttr)
90       return op->emitError() << "'" << kEscapeAttrName
91                              << "' is expected to be a bool array attribute";
92     if (arrayAttr.size() != op->getNumResults())
93       return op->emitError()
94              << "'" << kEscapeAttrName
95              << "' has wrong number of elements, expected "
96              << op->getNumResults() << ", got " << arrayAttr.size();
97     auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
98     if (!bufferizableOp)
99       return op->emitError()
100              << "'" << kEscapeAttrName << "' only valid on bufferizable ops";
101     for (const auto &it : llvm::enumerate(arrayAttr)) {
102       auto attr = it.value();
103       auto boolAttr = attr.dyn_cast<BoolAttr>();
104       if (!boolAttr)
105         return op->emitError() << "'" << kEscapeAttrName
106                                << "' is expected to be a bool array attribute";
107       if (!boolAttr.getValue())
108         continue;
109       if (!op->getResult(it.index()).getType().isa<TensorType>())
110         return op->emitError()
111                << "'" << kEscapeAttrName << "' only valid for tensor results";
112       if (!bufferizableOp.bufferizesToAllocation(op->getOpResult(it.index())))
113         return op->emitError() << "'" << kEscapeAttrName
114                                << "' only valid for allocation results";
115     }
116     return success();
117   }
118 
119   return op->emitError() << "attribute '" << attr.getName()
120                          << "' not supported by the bufferization dialect";
121 }
122