1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the parser for the dialect symbols, such as extended 10 // attributes and types. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "AsmParserImpl.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Dialect.h" 17 #include "mlir/IR/DialectImplementation.h" 18 #include "llvm/Support/SourceMgr.h" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 using llvm::MemoryBuffer; 23 using llvm::SourceMgr; 24 25 namespace { 26 /// This class provides the main implementation of the DialectAsmParser that 27 /// allows for dialects to parse attributes and types. This allows for dialect 28 /// hooking into the main MLIR parsing logic. 29 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> { 30 public: 31 CustomDialectAsmParser(StringRef fullSpec, Parser &parser) 32 : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser), 33 fullSpec(fullSpec) {} 34 ~CustomDialectAsmParser() override = default; 35 36 /// Returns the full specification of the symbol being parsed. This allows 37 /// for using a separate parser if necessary. 38 StringRef getFullSymbolSpec() const override { return fullSpec; } 39 40 private: 41 /// The full symbol specification. 42 StringRef fullSpec; 43 }; 44 } // namespace 45 46 /// 47 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' 48 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body 49 /// | '(' pretty-dialect-sym-contents+ ')' 50 /// | '[' pretty-dialect-sym-contents+ ']' 51 /// | '{' pretty-dialect-sym-contents+ '}' 52 /// | '[^[<({>\])}\0]+' 53 /// 54 ParseResult Parser::parseDialectSymbolBody(StringRef &body, 55 bool &isCodeCompletion) { 56 // Symbol bodies are a relatively unstructured format that contains a series 57 // of properly nested punctuation, with anything else in the middle. Scan 58 // ahead to find it and consume it if successful, otherwise emit an error. 59 const char *curPtr = getTokenSpelling().data(); 60 61 // Scan over the nested punctuation, bailing out on error and consuming until 62 // we find the end. We know that we're currently looking at the '<', so we can 63 // go until we find the matching '>' character. 64 assert(*curPtr == '<'); 65 SmallVector<char, 8> nestedPunctuation; 66 const char *codeCompleteLoc = state.lex.getCodeCompleteLoc(); 67 do { 68 // Handle code completions, which may appear in the middle of the symbol 69 // body. 70 if (curPtr == codeCompleteLoc) { 71 isCodeCompletion = true; 72 nestedPunctuation.clear(); 73 break; 74 } 75 76 char c = *curPtr++; 77 switch (c) { 78 case '\0': 79 // This also handles the EOF case. 80 if (!nestedPunctuation.empty()) { 81 return emitError() << "unbalanced '" << nestedPunctuation.back() 82 << "' character in pretty dialect name"; 83 } 84 return emitError("unexpected nul or EOF in pretty dialect name"); 85 case '<': 86 case '[': 87 case '(': 88 case '{': 89 nestedPunctuation.push_back(c); 90 continue; 91 92 case '-': 93 // The sequence `->` is treated as special token. 94 if (*curPtr == '>') 95 ++curPtr; 96 continue; 97 98 case '>': 99 if (nestedPunctuation.pop_back_val() != '<') 100 return emitError("unbalanced '>' character in pretty dialect name"); 101 break; 102 case ']': 103 if (nestedPunctuation.pop_back_val() != '[') 104 return emitError("unbalanced ']' character in pretty dialect name"); 105 break; 106 case ')': 107 if (nestedPunctuation.pop_back_val() != '(') 108 return emitError("unbalanced ')' character in pretty dialect name"); 109 break; 110 case '}': 111 if (nestedPunctuation.pop_back_val() != '{') 112 return emitError("unbalanced '}' character in pretty dialect name"); 113 break; 114 case '"': { 115 // Dispatch to the lexer to lex past strings. 116 resetToken(curPtr - 1); 117 curPtr = state.curToken.getEndLoc().getPointer(); 118 119 // Handle code completions, which may appear in the middle of the symbol 120 // body. 121 if (state.curToken.isCodeCompletion()) { 122 isCodeCompletion = true; 123 nestedPunctuation.clear(); 124 break; 125 } 126 127 // Otherwise, ensure this token was actually a string. 128 if (state.curToken.isNot(Token::string)) 129 return failure(); 130 break; 131 } 132 133 default: 134 continue; 135 } 136 } while (!nestedPunctuation.empty()); 137 138 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is 139 // consuming all this stuff, and return. 140 resetToken(curPtr); 141 142 unsigned length = curPtr - body.begin(); 143 body = StringRef(body.data(), length); 144 return success(); 145 } 146 147 /// Parse an extended dialect symbol. 148 template <typename Symbol, typename SymbolAliasMap, typename CreateFn> 149 static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases, 150 CreateFn &&createSymbol) { 151 Token tok = p.getToken(); 152 153 // Handle code completion of the extended symbol. 154 StringRef identifier = tok.getSpelling().drop_front(); 155 if (tok.isCodeCompletion() && identifier.empty()) 156 return p.codeCompleteDialectSymbol(aliases); 157 158 // Parse the dialect namespace. 159 SMLoc loc = p.getToken().getLoc(); 160 p.consumeToken(); 161 162 // Check to see if this is a pretty name. 163 StringRef dialectName; 164 StringRef symbolData; 165 std::tie(dialectName, symbolData) = identifier.split('.'); 166 bool isPrettyName = !symbolData.empty() || identifier.back() == '.'; 167 168 // Check to see if the symbol has trailing data, i.e. has an immediately 169 // following '<'. 170 bool hasTrailingData = 171 p.getToken().is(Token::less) && 172 identifier.bytes_end() == p.getTokenSpelling().bytes_begin(); 173 174 // If there is no '<' token following this, and if the typename contains no 175 // dot, then we are parsing a symbol alias. 176 if (!hasTrailingData && !isPrettyName) { 177 // Check for an alias for this type. 178 auto aliasIt = aliases.find(identifier); 179 if (aliasIt == aliases.end()) 180 return (p.emitWrongTokenError("undefined symbol alias id '" + identifier + 181 "'"), 182 nullptr); 183 return aliasIt->second; 184 } 185 186 // If this isn't an alias, we are parsing a dialect-specific symbol. If the 187 // name contains a dot, then this is the "pretty" form. If not, it is the 188 // verbose form that looks like <...>. 189 if (!isPrettyName) { 190 // Point the symbol data to the end of the dialect name to start. 191 symbolData = StringRef(dialectName.end(), 0); 192 193 // Parse the body of the symbol. 194 bool isCodeCompletion = false; 195 if (p.parseDialectSymbolBody(symbolData, isCodeCompletion)) 196 return nullptr; 197 symbolData = symbolData.drop_front(); 198 199 // If the body contained a code completion it won't have the trailing `>` 200 // token, so don't drop it. 201 if (!isCodeCompletion) 202 symbolData = symbolData.drop_back(); 203 } else { 204 loc = SMLoc::getFromPointer(symbolData.data()); 205 206 // If the dialect's symbol is followed immediately by a <, then lex the body 207 // of it into prettyName. 208 if (hasTrailingData && p.parseDialectSymbolBody(symbolData)) 209 return nullptr; 210 } 211 212 return createSymbol(dialectName, symbolData, loc); 213 } 214 215 /// Parse an extended attribute. 216 /// 217 /// extended-attribute ::= (dialect-attribute | attribute-alias) 218 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` 219 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? 220 /// attribute-alias ::= `#` alias-name 221 /// 222 Attribute Parser::parseExtendedAttr(Type type) { 223 MLIRContext *ctx = getContext(); 224 Attribute attr = parseExtendedSymbol<Attribute>( 225 *this, state.symbols.attributeAliasDefinitions, 226 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute { 227 // Parse an optional trailing colon type. 228 Type attrType = type; 229 if (consumeIf(Token::colon) && !(attrType = parseType())) 230 return Attribute(); 231 232 // If we found a registered dialect, then ask it to parse the attribute. 233 if (Dialect *dialect = 234 builder.getContext()->getOrLoadDialect(dialectName)) { 235 // Temporarily reset the lexer to let the dialect parse the attribute. 236 const char *curLexerPos = getToken().getLoc().getPointer(); 237 resetToken(symbolData.data()); 238 239 // Parse the attribute. 240 CustomDialectAsmParser customParser(symbolData, *this); 241 Attribute attr = dialect->parseAttribute(customParser, attrType); 242 resetToken(curLexerPos); 243 return attr; 244 } 245 246 // Otherwise, form a new opaque attribute. 247 return OpaqueAttr::getChecked( 248 [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName), 249 symbolData, attrType ? attrType : NoneType::get(ctx)); 250 }); 251 252 // Ensure that the attribute has the same type as requested. 253 if (attr && type && attr.getType() != type) { 254 emitError("attribute type different than expected: expected ") 255 << type << ", but got " << attr.getType(); 256 return nullptr; 257 } 258 return attr; 259 } 260 261 /// Parse an extended type. 262 /// 263 /// extended-type ::= (dialect-type | type-alias) 264 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` 265 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? 266 /// type-alias ::= `!` alias-name 267 /// 268 Type Parser::parseExtendedType() { 269 MLIRContext *ctx = getContext(); 270 return parseExtendedSymbol<Type>( 271 *this, state.symbols.typeAliasDefinitions, 272 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { 273 // If we found a registered dialect, then ask it to parse the type. 274 if (auto *dialect = ctx->getOrLoadDialect(dialectName)) { 275 // Temporarily reset the lexer to let the dialect parse the type. 276 const char *curLexerPos = getToken().getLoc().getPointer(); 277 resetToken(symbolData.data()); 278 279 // Parse the type. 280 CustomDialectAsmParser customParser(symbolData, *this); 281 Type type = dialect->parseType(customParser); 282 resetToken(curLexerPos); 283 return type; 284 } 285 286 // Otherwise, form a new opaque type. 287 return OpaqueType::getChecked([&] { return emitError(loc); }, 288 StringAttr::get(ctx, dialectName), 289 symbolData); 290 }); 291 } 292 293 //===----------------------------------------------------------------------===// 294 // mlir::parseAttribute/parseType 295 //===----------------------------------------------------------------------===// 296 297 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If 298 /// parsing failed, nullptr is returned. The number of bytes read from the input 299 /// string is returned in 'numRead'. 300 template <typename T, typename ParserFn> 301 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, 302 ParserFn &&parserFn) { 303 SourceMgr sourceMgr; 304 auto memBuffer = MemoryBuffer::getMemBuffer( 305 inputStr, /*BufferName=*/"<mlir_parser_buffer>", 306 /*RequiresNullTerminator=*/false); 307 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); 308 SymbolState aliasState; 309 ParserConfig config(context); 310 ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr, 311 /*codeCompleteContext=*/nullptr); 312 Parser parser(state); 313 314 SourceMgrDiagnosticHandler handler( 315 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), 316 parser.getContext()); 317 Token startTok = parser.getToken(); 318 T symbol = parserFn(parser); 319 if (!symbol) 320 return T(); 321 322 // Provide the number of bytes that were read. 323 Token endTok = parser.getToken(); 324 numRead = static_cast<size_t>(endTok.getLoc().getPointer() - 325 startTok.getLoc().getPointer()); 326 return symbol; 327 } 328 329 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { 330 size_t numRead = 0; 331 return parseAttribute(attrStr, context, numRead); 332 } 333 Attribute mlir::parseAttribute(StringRef attrStr, Type type) { 334 size_t numRead = 0; 335 return parseAttribute(attrStr, type, numRead); 336 } 337 338 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, 339 size_t &numRead) { 340 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { 341 return parser.parseAttribute(); 342 }); 343 } 344 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { 345 return parseSymbol<Attribute>( 346 attrStr, type.getContext(), numRead, 347 [type](Parser &parser) { return parser.parseAttribute(type); }); 348 } 349 350 Type mlir::parseType(StringRef typeStr, MLIRContext *context) { 351 size_t numRead = 0; 352 return parseType(typeStr, context, numRead); 353 } 354 355 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { 356 return parseSymbol<Type>(typeStr, context, numRead, 357 [](Parser &parser) { return parser.parseType(); }); 358 } 359