1 //===- Context.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/ODS/Context.h" 10 #include "mlir/Tools/PDLL/ODS/Constraint.h" 11 #include "mlir/Tools/PDLL/ODS/Dialect.h" 12 #include "mlir/Tools/PDLL/ODS/Operation.h" 13 #include "llvm/Support/ScopedPrinter.h" 14 #include "llvm/Support/raw_ostream.h" 15 16 using namespace mlir; 17 using namespace mlir::pdll::ods; 18 19 //===----------------------------------------------------------------------===// 20 // Context 21 //===----------------------------------------------------------------------===// 22 23 Context::Context() = default; 24 Context::~Context() = default; 25 26 const AttributeConstraint & 27 Context::insertAttributeConstraint(StringRef name, StringRef summary, 28 StringRef cppClass) { 29 std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name]; 30 if (!constraint) { 31 constraint.reset(new AttributeConstraint(name, summary, cppClass)); 32 } else { 33 assert(constraint->getCppClass() == cppClass && 34 constraint->getSummary() == summary && 35 "constraint with the same name was already registered with a " 36 "different class"); 37 } 38 return *constraint; 39 } 40 41 const TypeConstraint &Context::insertTypeConstraint(StringRef name, 42 StringRef summary, 43 StringRef cppClass) { 44 std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name]; 45 if (!constraint) 46 constraint.reset(new TypeConstraint(name, summary, cppClass)); 47 return *constraint; 48 } 49 50 Dialect &Context::insertDialect(StringRef name) { 51 std::unique_ptr<Dialect> &dialect = dialects[name]; 52 if (!dialect) 53 dialect.reset(new Dialect(name)); 54 return *dialect; 55 } 56 57 const Dialect *Context::lookupDialect(StringRef name) const { 58 auto it = dialects.find(name); 59 return it == dialects.end() ? nullptr : &*it->second; 60 } 61 62 std::pair<Operation *, bool> Context::insertOperation(StringRef name, 63 StringRef summary, 64 StringRef desc, 65 SMLoc loc) { 66 std::pair<StringRef, StringRef> dialectAndName = name.split('.'); 67 return insertDialect(dialectAndName.first) 68 .insertOperation(name, summary, desc, loc); 69 } 70 71 const Operation *Context::lookupOperation(StringRef name) const { 72 std::pair<StringRef, StringRef> dialectAndName = name.split('.'); 73 if (const Dialect *dialect = lookupDialect(dialectAndName.first)) 74 return dialect->lookupOperation(name); 75 return nullptr; 76 } 77 78 template <typename T> 79 SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) { 80 SmallVector<T *> storage; 81 for (auto &entry : map) 82 storage.push_back(entry.second.get()); 83 llvm::sort(storage, [](const auto &lhs, const auto &rhs) { 84 return lhs->getName() < rhs->getName(); 85 }); 86 return storage; 87 } 88 89 void Context::print(raw_ostream &os) const { 90 auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) { 91 switch (kind) { 92 case VariableLengthKind::Optional: 93 os << "Optional<" << cst << ">"; 94 break; 95 case VariableLengthKind::Single: 96 os << cst; 97 break; 98 case VariableLengthKind::Variadic: 99 os << "Variadic<" << cst << ">"; 100 break; 101 } 102 }; 103 104 llvm::ScopedPrinter printer(os); 105 llvm::DictScope odsScope(printer, "ODSContext"); 106 for (const Dialect *dialect : sortMapByName(dialects)) { 107 printer.startLine() << "Dialect `" << dialect->getName() << "` {\n"; 108 printer.indent(); 109 110 for (const Operation *op : sortMapByName(dialect->getOperations())) { 111 printer.startLine() << "Operation `" << op->getName() << "` {\n"; 112 printer.indent(); 113 114 // Attributes. 115 ArrayRef<Attribute> attributes = op->getAttributes(); 116 if (!attributes.empty()) { 117 printer.startLine() << "Attributes { "; 118 llvm::interleaveComma(attributes, os, [&](const Attribute &attr) { 119 os << attr.getName() << " : "; 120 121 auto kind = attr.isOptional() ? VariableLengthKind::Optional 122 : VariableLengthKind::Single; 123 printVariableLengthCst(attr.getConstraint().getName(), kind); 124 }); 125 os << " }\n"; 126 } 127 128 // Operands. 129 ArrayRef<OperandOrResult> operands = op->getOperands(); 130 if (!operands.empty()) { 131 printer.startLine() << "Operands { "; 132 llvm::interleaveComma( 133 operands, os, [&](const OperandOrResult &operand) { 134 os << operand.getName() << " : "; 135 printVariableLengthCst(operand.getConstraint().getName(), 136 operand.getVariableLengthKind()); 137 }); 138 os << " }\n"; 139 } 140 141 // Results. 142 ArrayRef<OperandOrResult> results = op->getResults(); 143 if (!results.empty()) { 144 printer.startLine() << "Results { "; 145 llvm::interleaveComma(results, os, [&](const OperandOrResult &result) { 146 os << result.getName() << " : "; 147 printVariableLengthCst(result.getConstraint().getName(), 148 result.getVariableLengthKind()); 149 }); 150 os << " }\n"; 151 } 152 153 printer.objectEnd(); 154 } 155 printer.objectEnd(); 156 } 157 for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) { 158 printer.startLine() << "AttributeConstraint `" << cst->getName() << "` {\n"; 159 printer.indent(); 160 161 printer.startLine() << "Summary: " << cst->getSummary() << "\n"; 162 printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; 163 printer.objectEnd(); 164 } 165 for (const TypeConstraint *cst : sortMapByName(typeConstraints)) { 166 printer.startLine() << "TypeConstraint `" << cst->getName() << "` {\n"; 167 printer.indent(); 168 169 printer.startLine() << "Summary: " << cst->getSummary() << "\n"; 170 printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; 171 printer.objectEnd(); 172 } 173 printer.objectEnd(); 174 } 175