1 //===- Context.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Tools/PDLL/ODS/Context.h"
10 #include "mlir/Tools/PDLL/ODS/Constraint.h"
11 #include "mlir/Tools/PDLL/ODS/Dialect.h"
12 #include "mlir/Tools/PDLL/ODS/Operation.h"
13 #include "llvm/Support/ScopedPrinter.h"
14 #include "llvm/Support/raw_ostream.h"
15 
16 using namespace mlir;
17 using namespace mlir::pdll::ods;
18 
19 //===----------------------------------------------------------------------===//
20 // Context
21 //===----------------------------------------------------------------------===//
22 
// Constructor and destructor are explicitly defaulted in this implementation
// file rather than in the header.
// NOTE(review): presumably this is so the types owned through the context's
// unique_ptr maps only need to be complete here -- confirm against Context.h.
23 Context::Context() = default;
24 Context::~Context() = default;
25 
26 const AttributeConstraint &
insertAttributeConstraint(StringRef name,StringRef summary,StringRef cppClass)27 Context::insertAttributeConstraint(StringRef name, StringRef summary,
28                                    StringRef cppClass) {
29   std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name];
30   if (!constraint) {
31     constraint.reset(new AttributeConstraint(name, summary, cppClass));
32   } else {
33     assert(constraint->getCppClass() == cppClass &&
34            constraint->getSummary() == summary &&
35            "constraint with the same name was already registered with a "
36            "different class");
37   }
38   return *constraint;
39 }
40 
insertTypeConstraint(StringRef name,StringRef summary,StringRef cppClass)41 const TypeConstraint &Context::insertTypeConstraint(StringRef name,
42                                                     StringRef summary,
43                                                     StringRef cppClass) {
44   std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name];
45   if (!constraint)
46     constraint.reset(new TypeConstraint(name, summary, cppClass));
47   return *constraint;
48 }
49 
insertDialect(StringRef name)50 Dialect &Context::insertDialect(StringRef name) {
51   std::unique_ptr<Dialect> &dialect = dialects[name];
52   if (!dialect)
53     dialect.reset(new Dialect(name));
54   return *dialect;
55 }
56 
lookupDialect(StringRef name) const57 const Dialect *Context::lookupDialect(StringRef name) const {
58   auto it = dialects.find(name);
59   return it == dialects.end() ? nullptr : &*it->second;
60 }
61 
62 std::pair<Operation *, bool>
insertOperation(StringRef name,StringRef summary,StringRef desc,StringRef nativeClassName,bool supportsResultTypeInferrence,SMLoc loc)63 Context::insertOperation(StringRef name, StringRef summary, StringRef desc,
64                          StringRef nativeClassName,
65                          bool supportsResultTypeInferrence, SMLoc loc) {
66   std::pair<StringRef, StringRef> dialectAndName = name.split('.');
67   return insertDialect(dialectAndName.first)
68       .insertOperation(name, summary, desc, nativeClassName,
69                        supportsResultTypeInferrence, loc);
70 }
71 
lookupOperation(StringRef name) const72 const Operation *Context::lookupOperation(StringRef name) const {
73   std::pair<StringRef, StringRef> dialectAndName = name.split('.');
74   if (const Dialect *dialect = lookupDialect(dialectAndName.first))
75     return dialect->lookupOperation(name);
76   return nullptr;
77 }
78 
79 template <typename T>
sortMapByName(const llvm::StringMap<std::unique_ptr<T>> & map)80 SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) {
81   SmallVector<T *> storage;
82   for (auto &entry : map)
83     storage.push_back(entry.second.get());
84   llvm::sort(storage, [](const auto &lhs, const auto &rhs) {
85     return lhs->getName() < rhs->getName();
86   });
87   return storage;
88 }
89 
print(raw_ostream & os) const90 void Context::print(raw_ostream &os) const {
91   auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) {
92     switch (kind) {
93     case VariableLengthKind::Optional:
94       os << "Optional<" << cst << ">";
95       break;
96     case VariableLengthKind::Single:
97       os << cst;
98       break;
99     case VariableLengthKind::Variadic:
100       os << "Variadic<" << cst << ">";
101       break;
102     }
103   };
104 
105   llvm::ScopedPrinter printer(os);
106   llvm::DictScope odsScope(printer, "ODSContext");
107   for (const Dialect *dialect : sortMapByName(dialects)) {
108     printer.startLine() << "Dialect `" << dialect->getName() << "` {\n";
109     printer.indent();
110 
111     for (const Operation *op : sortMapByName(dialect->getOperations())) {
112       printer.startLine() << "Operation `" << op->getName() << "` {\n";
113       printer.indent();
114 
115       // Attributes.
116       ArrayRef<Attribute> attributes = op->getAttributes();
117       if (!attributes.empty()) {
118         printer.startLine() << "Attributes { ";
119         llvm::interleaveComma(attributes, os, [&](const Attribute &attr) {
120           os << attr.getName() << " : ";
121 
122           auto kind = attr.isOptional() ? VariableLengthKind::Optional
123                                         : VariableLengthKind::Single;
124           printVariableLengthCst(attr.getConstraint().getDemangledName(), kind);
125         });
126         os << " }\n";
127       }
128 
129       // Operands.
130       ArrayRef<OperandOrResult> operands = op->getOperands();
131       if (!operands.empty()) {
132         printer.startLine() << "Operands { ";
133         llvm::interleaveComma(
134             operands, os, [&](const OperandOrResult &operand) {
135               os << operand.getName() << " : ";
136               printVariableLengthCst(operand.getConstraint().getDemangledName(),
137                                      operand.getVariableLengthKind());
138             });
139         os << " }\n";
140       }
141 
142       // Results.
143       ArrayRef<OperandOrResult> results = op->getResults();
144       if (!results.empty()) {
145         printer.startLine() << "Results { ";
146         llvm::interleaveComma(results, os, [&](const OperandOrResult &result) {
147           os << result.getName() << " : ";
148           printVariableLengthCst(result.getConstraint().getDemangledName(),
149                                  result.getVariableLengthKind());
150         });
151         os << " }\n";
152       }
153 
154       printer.objectEnd();
155     }
156     printer.objectEnd();
157   }
158   for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) {
159     printer.startLine() << "AttributeConstraint `" << cst->getDemangledName()
160                         << "` {\n";
161     printer.indent();
162 
163     printer.startLine() << "Summary: " << cst->getSummary() << "\n";
164     printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
165     printer.objectEnd();
166   }
167   for (const TypeConstraint *cst : sortMapByName(typeConstraints)) {
168     printer.startLine() << "TypeConstraint `" << cst->getDemangledName()
169                         << "` {\n";
170     printer.indent();
171 
172     printer.startLine() << "Summary: " << cst->getSummary() << "\n";
173     printer.startLine() << "CppClass: " << cst->getCppClass() << "\n";
174     printer.objectEnd();
175   }
176   printer.objectEnd();
177 }
178