1 //===- Context.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Tools/PDLL/ODS/Context.h" 10 #include "mlir/Tools/PDLL/ODS/Constraint.h" 11 #include "mlir/Tools/PDLL/ODS/Dialect.h" 12 #include "mlir/Tools/PDLL/ODS/Operation.h" 13 #include "llvm/Support/ScopedPrinter.h" 14 #include "llvm/Support/raw_ostream.h" 15 16 using namespace mlir; 17 using namespace mlir::pdll::ods; 18 19 //===----------------------------------------------------------------------===// 20 // Context 21 //===----------------------------------------------------------------------===// 22 23 Context::Context() = default; 24 Context::~Context() = default; 25 26 const AttributeConstraint & 27 Context::insertAttributeConstraint(StringRef name, StringRef summary, 28 StringRef cppClass) { 29 std::unique_ptr<AttributeConstraint> &constraint = attributeConstraints[name]; 30 if (!constraint) { 31 constraint.reset(new AttributeConstraint(name, summary, cppClass)); 32 } else { 33 assert(constraint->getCppClass() == cppClass && 34 constraint->getSummary() == summary && 35 "constraint with the same name was already registered with a " 36 "different class"); 37 } 38 return *constraint; 39 } 40 41 const TypeConstraint &Context::insertTypeConstraint(StringRef name, 42 StringRef summary, 43 StringRef cppClass) { 44 std::unique_ptr<TypeConstraint> &constraint = typeConstraints[name]; 45 if (!constraint) 46 constraint.reset(new TypeConstraint(name, summary, cppClass)); 47 return *constraint; 48 } 49 50 Dialect &Context::insertDialect(StringRef name) { 51 std::unique_ptr<Dialect> &dialect = dialects[name]; 52 if (!dialect) 53 dialect.reset(new Dialect(name)); 54 return *dialect; 55 } 56 57 const Dialect *Context::lookupDialect(StringRef name) const { 58 auto it = dialects.find(name); 59 return it == dialects.end() ? nullptr : &*it->second; 60 } 61 62 std::pair<Operation *, bool> 63 Context::insertOperation(StringRef name, StringRef summary, StringRef desc, 64 StringRef nativeClassName, 65 bool supportsResultTypeInferrence, SMLoc loc) { 66 std::pair<StringRef, StringRef> dialectAndName = name.split('.'); 67 return insertDialect(dialectAndName.first) 68 .insertOperation(name, summary, desc, nativeClassName, 69 supportsResultTypeInferrence, loc); 70 } 71 72 const Operation *Context::lookupOperation(StringRef name) const { 73 std::pair<StringRef, StringRef> dialectAndName = name.split('.'); 74 if (const Dialect *dialect = lookupDialect(dialectAndName.first)) 75 return dialect->lookupOperation(name); 76 return nullptr; 77 } 78 79 template <typename T> 80 SmallVector<T *> sortMapByName(const llvm::StringMap<std::unique_ptr<T>> &map) { 81 SmallVector<T *> storage; 82 for (auto &entry : map) 83 storage.push_back(entry.second.get()); 84 llvm::sort(storage, [](const auto &lhs, const auto &rhs) { 85 return lhs->getName() < rhs->getName(); 86 }); 87 return storage; 88 } 89 90 void Context::print(raw_ostream &os) const { 91 auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) { 92 switch (kind) { 93 case VariableLengthKind::Optional: 94 os << "Optional<" << cst << ">"; 95 break; 96 case VariableLengthKind::Single: 97 os << cst; 98 break; 99 case VariableLengthKind::Variadic: 100 os << "Variadic<" << cst << ">"; 101 break; 102 } 103 }; 104 105 llvm::ScopedPrinter printer(os); 106 llvm::DictScope odsScope(printer, "ODSContext"); 107 for (const Dialect *dialect : sortMapByName(dialects)) { 108 printer.startLine() << "Dialect `" << dialect->getName() << "` {\n"; 109 printer.indent(); 110 111 for (const Operation *op : sortMapByName(dialect->getOperations())) { 112 printer.startLine() << "Operation `" << op->getName() << "` {\n"; 113 printer.indent(); 114 115 // Attributes. 116 ArrayRef<Attribute> attributes = op->getAttributes(); 117 if (!attributes.empty()) { 118 printer.startLine() << "Attributes { "; 119 llvm::interleaveComma(attributes, os, [&](const Attribute &attr) { 120 os << attr.getName() << " : "; 121 122 auto kind = attr.isOptional() ? VariableLengthKind::Optional 123 : VariableLengthKind::Single; 124 printVariableLengthCst(attr.getConstraint().getDemangledName(), kind); 125 }); 126 os << " }\n"; 127 } 128 129 // Operands. 130 ArrayRef<OperandOrResult> operands = op->getOperands(); 131 if (!operands.empty()) { 132 printer.startLine() << "Operands { "; 133 llvm::interleaveComma( 134 operands, os, [&](const OperandOrResult &operand) { 135 os << operand.getName() << " : "; 136 printVariableLengthCst(operand.getConstraint().getDemangledName(), 137 operand.getVariableLengthKind()); 138 }); 139 os << " }\n"; 140 } 141 142 // Results. 143 ArrayRef<OperandOrResult> results = op->getResults(); 144 if (!results.empty()) { 145 printer.startLine() << "Results { "; 146 llvm::interleaveComma(results, os, [&](const OperandOrResult &result) { 147 os << result.getName() << " : "; 148 printVariableLengthCst(result.getConstraint().getDemangledName(), 149 result.getVariableLengthKind()); 150 }); 151 os << " }\n"; 152 } 153 154 printer.objectEnd(); 155 } 156 printer.objectEnd(); 157 } 158 for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) { 159 printer.startLine() << "AttributeConstraint `" << cst->getDemangledName() 160 << "` {\n"; 161 printer.indent(); 162 163 printer.startLine() << "Summary: " << cst->getSummary() << "\n"; 164 printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; 165 printer.objectEnd(); 166 } 167 for (const TypeConstraint *cst : sortMapByName(typeConstraints)) { 168 printer.startLine() << "TypeConstraint `" << cst->getDemangledName() 169 << "` {\n"; 170 printer.indent(); 171 172 printer.startLine() << "Summary: " << cst->getSummary() << "\n"; 173 printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; 174 printer.objectEnd(); 175 } 176 printer.objectEnd(); 177 } 178