1 //===- Context.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/ODS/Context.h"
10 #include "mlir/Tools/PDLL/ODS/Constraint.h"
11 #include "mlir/Tools/PDLL/ODS/Dialect.h"
12 #include "mlir/Tools/PDLL/ODS/Operation.h"
13 #include "llvm/Support/ScopedPrinter.h"
14 #include "llvm/Support/raw_ostream.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdll::ods;
18 
19 //===----------------------------------------------------------------------===//
20 // Context
21 //===----------------------------------------------------------------------===//
22 
// The constructor and destructor are defaulted out-of-line in this file rather
// than in the header. NOTE(review): presumably so the std::unique_ptr members
// are created/destroyed where the owned types are complete (Constraint.h,
// Dialect.h and Operation.h are included above) — confirm against Context.h.
Context::Context() = default;
Context::~Context() = default;
25 
26 const AttributeConstraint &
27 Context::insertAttributeConstraint(StringRef name, StringRef summary,
28                                    StringRef cppClass) {
29   std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
30   if (!constraint) {
31     constraint.reset(new AttributeConstraint(name, summary, cppClass));
32   } else {
33     assert(constraint->getCppClass() == cppClass &&
34            constraint->getSummary() == summary &&
35            "constraint with the same name was already registered with a "
36            "different class");
37   }
38   return *constraint;
39 }
40 
41 const TypeConstraint &Context::insertTypeConstraint(StringRef name,
42                                                     StringRef summary,
43                                                     StringRef cppClass) {
44   std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
45   if (!constraint)
46     constraint.reset(new TypeConstraint(name, summary, cppClass));
47   return *constraint;
48 }
49 
50 Dialect &Context::insertDialect(StringRef name) {
51   std::unique_ptr<Dialect> &dialect = dialects[name];
52   if (!dialect)
53     dialect.reset(new Dialect(name));
54   return *dialect;
55 }
56 
57 const Dialect *Context::lookupDialect(StringRef name) const {
58   auto it = dialects.find(name);
59   return it == dialects.end() ? nullptr : &*it->second;
60 }
61 
62 std::pair<Operation *, bool> Context::insertOperation(StringRef name,
63                                                       StringRef summary,
64                                                       StringRef desc,
65                                                       SMLoc loc) {
66   std::pair<StringRef, StringRef> dialectAndName = name.split('.');
67   return insertDialect(dialectAndName.first)
68       .insertOperation(name, summary, desc, loc);
69 }
70 
71 const Operation *Context::lookupOperation(StringRef name) const {
72   std::pair<StringRef, StringRef> dialectAndName = name.split('.');
73   if (const Dialect *dialect = lookupDialect(dialectAndName.first))
74     return dialect->lookupOperation(name);
75   return nullptr;
76 }
77 
78 template <typename T>
79 SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
80   SmallVector<T *> storage;
81   for (auto &entry : map)
82     storage.push_back(entry.second.get());
83   llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
84     return lhs->getName() < rhs->getName();
85   });
86   return storage;
87 }
88 
89 void Context::print(raw_ostream &os) const {
90   auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
91     switch (kind) {
92     case VariableLengthKind::Optional:
93       os << "Optional<" << cst << ">";
94       break;
95     case VariableLengthKind::Single:
96       os << cst;
97       break;
98     case VariableLengthKind::Variadic:
99       os << "Variadic<" << cst << ">";
100       break;
101     }
102   };
103 
104   llvm::ScopedPrinter printer(os);
105   llvm::DictScope odsScope(printer, "ODSContext");
106   for (const Dialect *dialect : sortMapByName(dialects)) {
107     printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
108     printer.indent();
109 
110     for (const Operation *op : sortMapByName(dialect->getOperations())) {
111       printer.startLine() << "Operation `" << op->getName() << "` {\n";
112       printer.indent();
113 
114       // Attributes.
115       ArrayRef<Attribute> attributes = op->getAttributes();
116       if (!attributes.empty()) {
117         printer.startLine() << "Attributes { ";
118         llvm::interleaveComma(attributes, os, [&](const Attribute &attr) {
119           os << attr.getName() << " : ";
120 
121           auto kind = attr.isOptional() ? VariableLengthKind::Optional
122                                         : VariableLengthKind::Single;
123           printVariableLengthCst(attr.getConstraint().getName(), kind);
124         });
125         os << " }\n";
126       }
127 
128       // Operands.
129       ArrayRef<OperandOrResult> operands = op->getOperands();
130       if (!operands.empty()) {
131         printer.startLine() << "Operands { ";
132         llvm::interleaveComma(
133             operands, os, [&](const OperandOrResult &operand) {
134               os << operand.getName() << " : ";
135               printVariableLengthCst(operand.getConstraint().getName(),
136                                      operand.getVariableLengthKind());
137             });
138         os << " }\n";
139       }
140 
141       // Results.
142       ArrayRef<OperandOrResult> results = op->getResults();
143       if (!results.empty()) {
144         printer.startLine() << "Results { ";
145         llvm::interleaveComma(results, os, [&](const OperandOrResult &result) {
146           os << result.getName() << " : ";
147           printVariableLengthCst(result.getConstraint().getName(),
148                                  result.getVariableLengthKind());
149         });
150         os << " }\n";
151       }
152 
153       printer.objectEnd();
154     }
155     printer.objectEnd();
156   }
157   for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) {
158     printer.startLine() << "AttributeConstraint `" << cst->getName() << "` {\n";
159     printer.indent();
160 
161     printer.startLine() << "Summary: " << cst->getSummary() << "\n";
162     printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
163     printer.objectEnd();
164   }
165   for (const TypeConstraint *cst : sortMapByName(typeConstraints)) {
166     printer.startLine() << "TypeConstraint `" << cst->getName() << "` {\n";
167     printer.indent();
168 
169     printer.startLine() << "Summary: " << cst->getSummary() << "\n";
170     printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
171     printer.objectEnd();
172   }
173   printer.objectEnd();
174 }
175