1 //===- Context.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Tools/PDLL/ODS/Context.h"
10 #include "mlir/Tools/PDLL/ODS/Constraint.h"
11 #include "mlir/Tools/PDLL/ODS/Dialect.h"
12 #include "mlir/Tools/PDLL/ODS/Operation.h"
13 #include "llvm/Support/ScopedPrinter.h"
14 #include "llvm/Support/raw_ostream.h"
15
16 using namespace mlir;
17 using namespace mlir::pdll::ods;
18
19 //===----------------------------------------------------------------------===//
20 // Context
21 //===----------------------------------------------------------------------===//
22
// The constructor and destructor are defaulted here, out of line, rather than
// in the header. NOTE(review): presumably this is so the owned
// Dialect/constraint objects are complete types (their headers are included
// above) at the point where the owning containers are destroyed. Confirm
// against Context.h.
Context::Context() = default;
Context::~Context() = default;
25
26 const AttributeConstraint &
insertAttributeConstraint(StringRef name,StringRef summary,StringRef cppClass)27 Context::insertAttributeConstraint(StringRef name, StringRef summary,
28 StringRef cppClass) {
29 std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
30 if (!constraint) {
31 constraint.reset(new AttributeConstraint(name, summary, cppClass));
32 } else {
33 assert(constraint->getCppClass() == cppClass &&
34 constraint->getSummary() == summary &&
35 "constraint with the same name was already registered with a "
36 "different class");
37 }
38 return *constraint;
39 }
40
insertTypeConstraint(StringRef name,StringRef summary,StringRef cppClass)41 const TypeConstraint &Context::insertTypeConstraint(StringRef name,
42 StringRef summary,
43 StringRef cppClass) {
44 std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
45 if (!constraint)
46 constraint.reset(new TypeConstraint(name, summary, cppClass));
47 return *constraint;
48 }
49
insertDialect(StringRef name)50 Dialect &Context::insertDialect(StringRef name) {
51 std::unique_ptr<Dialect> &dialect = dialects[name];
52 if (!dialect)
53 dialect.reset(new Dialect(name));
54 return *dialect;
55 }
56
lookupDialect(StringRef name) const57 const Dialect *Context::lookupDialect(StringRef name) const {
58 auto it = dialects.find(name);
59 return it == dialects.end() ? nullptr : &*it->second;
60 }
61
62 std::pair<Operation *, bool>
insertOperation(StringRef name,StringRef summary,StringRef desc,StringRef nativeClassName,bool supportsResultTypeInferrence,SMLoc loc)63 Context::insertOperation(StringRef name, StringRef summary, StringRef desc,
64 StringRef nativeClassName,
65 bool supportsResultTypeInferrence, SMLoc loc) {
66 std::pair<StringRef, StringRef> dialectAndName = name.split('.');
67 return insertDialect(dialectAndName.first)
68 .insertOperation(name, summary, desc, nativeClassName,
69 supportsResultTypeInferrence, loc);
70 }
71
lookupOperation(StringRef name) const72 const Operation *Context::lookupOperation(StringRef name) const {
73 std::pair<StringRef, StringRef> dialectAndName = name.split('.');
74 if (const Dialect *dialect = lookupDialect(dialectAndName.first))
75 return dialect->lookupOperation(name);
76 return nullptr;
77 }
78
79 template <typename T>
sortMapByName(const llvm::StringMap<std::unique_ptr<T>> & map)80 SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
81 SmallVector<T *> storage;
82 for (auto &entry : map)
83 storage.push_back(entry.second.get());
84 llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
85 return lhs->getName() < rhs->getName();
86 });
87 return storage;
88 }
89
print(raw_ostream & os) const90 void Context::print(raw_ostream &os) const {
91 auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
92 switch (kind) {
93 case VariableLengthKind::Optional:
94 os << "Optional<" << cst << ">";
95 break;
96 case VariableLengthKind::Single:
97 os << cst;
98 break;
99 case VariableLengthKind::Variadic:
100 os << "Variadic<" << cst << ">";
101 break;
102 }
103 };
104
105 llvm::ScopedPrinter printer(os);
106 llvm::DictScope odsScope(printer, "ODSContext");
107 for (const Dialect *dialect : sortMapByName(dialects)) {
108 printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
109 printer.indent();
110
111 for (const Operation *op : sortMapByName(dialect->getOperations())) {
112 printer.startLine() << "Operation `" << op->getName() << "` {\n";
113 printer.indent();
114
115 // Attributes.
116 ArrayRef<Attribute> attributes = op->getAttributes();
117 if (!attributes.empty()) {
118 printer.startLine() << "Attributes { ";
119 llvm::interleaveComma(attributes, os, [&](const Attribute &attr) {
120 os << attr.getName() << " : ";
121
122 auto kind = attr.isOptional() ? VariableLengthKind::Optional
123 : VariableLengthKind::Single;
124 printVariableLengthCst(attr.getConstraint().getDemangledName(), kind);
125 });
126 os << " }\n";
127 }
128
129 // Operands.
130 ArrayRef<OperandOrResult> operands = op->getOperands();
131 if (!operands.empty()) {
132 printer.startLine() << "Operands { ";
133 llvm::interleaveComma(
134 operands, os, [&](const OperandOrResult &operand) {
135 os << operand.getName() << " : ";
136 printVariableLengthCst(operand.getConstraint().getDemangledName(),
137 operand.getVariableLengthKind());
138 });
139 os << " }\n";
140 }
141
142 // Results.
143 ArrayRef<OperandOrResult> results = op->getResults();
144 if (!results.empty()) {
145 printer.startLine() << "Results { ";
146 llvm::interleaveComma(results, os, [&](const OperandOrResult &result) {
147 os << result.getName() << " : ";
148 printVariableLengthCst(result.getConstraint().getDemangledName(),
149 result.getVariableLengthKind());
150 });
151 os << " }\n";
152 }
153
154 printer.objectEnd();
155 }
156 printer.objectEnd();
157 }
158 for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) {
159 printer.startLine() << "AttributeConstraint `" << cst->getDemangledName()
160 << "` {\n";
161 printer.indent();
162
163 printer.startLine() << "Summary: " << cst->getSummary() << "\n";
164 printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
165 printer.objectEnd();
166 }
167 for (const TypeConstraint *cst : sortMapByName(typeConstraints)) {
168 printer.startLine() << "TypeConstraint `" << cst->getDemangledName()
169 << "` {\n";
170 printer.indent();
171
172 printer.startLine() << "Summary: " << cst->getSummary() << "\n";
173 printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
174 printer.objectEnd();
175 }
176 printer.objectEnd();
177 }
178