1 //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the dialect symbols, such as extended
10 // attributes and types.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "AsmParserImpl.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/DialectImplementation.h"
18 #include "llvm/Support/SourceMgr.h"
19
20 using namespace mlir;
21 using namespace mlir::detail;
22 using llvm::MemoryBuffer;
23 using llvm::SourceMgr;
24
25 namespace {
26 /// This class provides the main implementation of the DialectAsmParser that
27 /// allows for dialects to parse attributes and types. This allows for dialect
28 /// hooking into the main MLIR parsing logic.
29 class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
30 public:
CustomDialectAsmParser(StringRef fullSpec,Parser & parser)31 CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
32 : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
33 fullSpec(fullSpec) {}
34 ~CustomDialectAsmParser() override = default;
35
36 /// Returns the full specification of the symbol being parsed. This allows
37 /// for using a separate parser if necessary.
getFullSymbolSpec() const38 StringRef getFullSymbolSpec() const override { return fullSpec; }
39
40 private:
41 /// The full symbol specification.
42 StringRef fullSpec;
43 };
44 } // namespace
45
46 ///
47 /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
48 /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
49 /// | '(' pretty-dialect-sym-contents+ ')'
50 /// | '[' pretty-dialect-sym-contents+ ']'
51 /// | '{' pretty-dialect-sym-contents+ '}'
52 /// | '[^[<({>\])}\0]+'
53 ///
parseDialectSymbolBody(StringRef & body,bool & isCodeCompletion)54 ParseResult Parser::parseDialectSymbolBody(StringRef &body,
55 bool &isCodeCompletion) {
56 // Symbol bodies are a relatively unstructured format that contains a series
57 // of properly nested punctuation, with anything else in the middle. Scan
58 // ahead to find it and consume it if successful, otherwise emit an error.
59 const char *curPtr = getTokenSpelling().data();
60
61 // Scan over the nested punctuation, bailing out on error and consuming until
62 // we find the end. We know that we're currently looking at the '<', so we can
63 // go until we find the matching '>' character.
64 assert(*curPtr == '<');
65 SmallVector<char, 8> nestedPunctuation;
66 const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
67 do {
68 // Handle code completions, which may appear in the middle of the symbol
69 // body.
70 if (curPtr == codeCompleteLoc) {
71 isCodeCompletion = true;
72 nestedPunctuation.clear();
73 break;
74 }
75
76 char c = *curPtr++;
77 switch (c) {
78 case '\0':
79 // This also handles the EOF case.
80 if (!nestedPunctuation.empty()) {
81 return emitError() << "unbalanced '" << nestedPunctuation.back()
82 << "' character in pretty dialect name";
83 }
84 return emitError("unexpected nul or EOF in pretty dialect name");
85 case '<':
86 case '[':
87 case '(':
88 case '{':
89 nestedPunctuation.push_back(c);
90 continue;
91
92 case '-':
93 // The sequence `->` is treated as special token.
94 if (*curPtr == '>')
95 ++curPtr;
96 continue;
97
98 case '>':
99 if (nestedPunctuation.pop_back_val() != '<')
100 return emitError("unbalanced '>' character in pretty dialect name");
101 break;
102 case ']':
103 if (nestedPunctuation.pop_back_val() != '[')
104 return emitError("unbalanced ']' character in pretty dialect name");
105 break;
106 case ')':
107 if (nestedPunctuation.pop_back_val() != '(')
108 return emitError("unbalanced ')' character in pretty dialect name");
109 break;
110 case '}':
111 if (nestedPunctuation.pop_back_val() != '{')
112 return emitError("unbalanced '}' character in pretty dialect name");
113 break;
114 case '"': {
115 // Dispatch to the lexer to lex past strings.
116 resetToken(curPtr - 1);
117 curPtr = state.curToken.getEndLoc().getPointer();
118
119 // Handle code completions, which may appear in the middle of the symbol
120 // body.
121 if (state.curToken.isCodeCompletion()) {
122 isCodeCompletion = true;
123 nestedPunctuation.clear();
124 break;
125 }
126
127 // Otherwise, ensure this token was actually a string.
128 if (state.curToken.isNot(Token::string))
129 return failure();
130 break;
131 }
132
133 default:
134 continue;
135 }
136 } while (!nestedPunctuation.empty());
137
138 // Ok, we succeeded, remember where we stopped, reset the lexer to know it is
139 // consuming all this stuff, and return.
140 resetToken(curPtr);
141
142 unsigned length = curPtr - body.begin();
143 body = StringRef(body.data(), length);
144 return success();
145 }
146
147 /// Parse an extended dialect symbol.
148 template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
parseExtendedSymbol(Parser & p,SymbolAliasMap & aliases,CreateFn && createSymbol)149 static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases,
150 CreateFn &&createSymbol) {
151 Token tok = p.getToken();
152
153 // Handle code completion of the extended symbol.
154 StringRef identifier = tok.getSpelling().drop_front();
155 if (tok.isCodeCompletion() && identifier.empty())
156 return p.codeCompleteDialectSymbol(aliases);
157
158 // Parse the dialect namespace.
159 SMLoc loc = p.getToken().getLoc();
160 p.consumeToken();
161
162 // Check to see if this is a pretty name.
163 StringRef dialectName;
164 StringRef symbolData;
165 std::tie(dialectName, symbolData) = identifier.split('.');
166 bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
167
168 // Check to see if the symbol has trailing data, i.e. has an immediately
169 // following '<'.
170 bool hasTrailingData =
171 p.getToken().is(Token::less) &&
172 identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
173
174 // If there is no '<' token following this, and if the typename contains no
175 // dot, then we are parsing a symbol alias.
176 if (!hasTrailingData && !isPrettyName) {
177 // Check for an alias for this type.
178 auto aliasIt = aliases.find(identifier);
179 if (aliasIt == aliases.end())
180 return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
181 "'"),
182 nullptr);
183 return aliasIt->second;
184 }
185
186 // If this isn't an alias, we are parsing a dialect-specific symbol. If the
187 // name contains a dot, then this is the "pretty" form. If not, it is the
188 // verbose form that looks like <...>.
189 if (!isPrettyName) {
190 // Point the symbol data to the end of the dialect name to start.
191 symbolData = StringRef(dialectName.end(), 0);
192
193 // Parse the body of the symbol.
194 bool isCodeCompletion = false;
195 if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
196 return nullptr;
197 symbolData = symbolData.drop_front();
198
199 // If the body contained a code completion it won't have the trailing `>`
200 // token, so don't drop it.
201 if (!isCodeCompletion)
202 symbolData = symbolData.drop_back();
203 } else {
204 loc = SMLoc::getFromPointer(symbolData.data());
205
206 // If the dialect's symbol is followed immediately by a <, then lex the body
207 // of it into prettyName.
208 if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
209 return nullptr;
210 }
211
212 return createSymbol(dialectName, symbolData, loc);
213 }
214
215 /// Parse an extended attribute.
216 ///
217 /// extended-attribute ::= (dialect-attribute | attribute-alias)
218 /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>`
219 /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body?
220 /// attribute-alias ::= `#` alias-name
221 ///
parseExtendedAttr(Type type)222 Attribute Parser::parseExtendedAttr(Type type) {
223 MLIRContext *ctx = getContext();
224 Attribute attr = parseExtendedSymbol<Attribute>(
225 *this, state.symbols.attributeAliasDefinitions,
226 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
227 // Parse an optional trailing colon type.
228 Type attrType = type;
229 if (consumeIf(Token::colon) && !(attrType = parseType()))
230 return Attribute();
231
232 // If we found a registered dialect, then ask it to parse the attribute.
233 if (Dialect *dialect =
234 builder.getContext()->getOrLoadDialect(dialectName)) {
235 // Temporarily reset the lexer to let the dialect parse the attribute.
236 const char *curLexerPos = getToken().getLoc().getPointer();
237 resetToken(symbolData.data());
238
239 // Parse the attribute.
240 CustomDialectAsmParser customParser(symbolData, *this);
241 Attribute attr = dialect->parseAttribute(customParser, attrType);
242 resetToken(curLexerPos);
243 return attr;
244 }
245
246 // Otherwise, form a new opaque attribute.
247 return OpaqueAttr::getChecked(
248 [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
249 symbolData, attrType ? attrType : NoneType::get(ctx));
250 });
251
252 // Ensure that the attribute has the same type as requested.
253 if (attr && type && attr.getType() != type) {
254 emitError("attribute type different than expected: expected ")
255 << type << ", but got " << attr.getType();
256 return nullptr;
257 }
258 return attr;
259 }
260
261 /// Parse an extended type.
262 ///
263 /// extended-type ::= (dialect-type | type-alias)
264 /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
265 /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
266 /// type-alias ::= `!` alias-name
267 ///
parseExtendedType()268 Type Parser::parseExtendedType() {
269 MLIRContext *ctx = getContext();
270 return parseExtendedSymbol<Type>(
271 *this, state.symbols.typeAliasDefinitions,
272 [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
273 // If we found a registered dialect, then ask it to parse the type.
274 if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
275 // Temporarily reset the lexer to let the dialect parse the type.
276 const char *curLexerPos = getToken().getLoc().getPointer();
277 resetToken(symbolData.data());
278
279 // Parse the type.
280 CustomDialectAsmParser customParser(symbolData, *this);
281 Type type = dialect->parseType(customParser);
282 resetToken(curLexerPos);
283 return type;
284 }
285
286 // Otherwise, form a new opaque type.
287 return OpaqueType::getChecked([&] { return emitError(loc); },
288 StringAttr::get(ctx, dialectName),
289 symbolData);
290 });
291 }
292
293 //===----------------------------------------------------------------------===//
294 // mlir::parseAttribute/parseType
295 //===----------------------------------------------------------------------===//
296
297 /// Parses a symbol, of type 'T', and returns it if parsing was successful. If
298 /// parsing failed, nullptr is returned. The number of bytes read from the input
299 /// string is returned in 'numRead'.
300 template <typename T, typename ParserFn>
parseSymbol(StringRef inputStr,MLIRContext * context,size_t & numRead,ParserFn && parserFn)301 static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead,
302 ParserFn &&parserFn) {
303 SourceMgr sourceMgr;
304 auto memBuffer = MemoryBuffer::getMemBuffer(
305 inputStr, /*BufferName=*/"<mlir_parser_buffer>",
306 /*RequiresNullTerminator=*/false);
307 sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
308 SymbolState aliasState;
309 ParserConfig config(context);
310 ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
311 /*codeCompleteContext=*/nullptr);
312 Parser parser(state);
313
314 SourceMgrDiagnosticHandler handler(
315 const_cast<llvm::SourceMgr &>(parser.getSourceMgr()),
316 parser.getContext());
317 Token startTok = parser.getToken();
318 T symbol = parserFn(parser);
319 if (!symbol)
320 return T();
321
322 // Provide the number of bytes that were read.
323 Token endTok = parser.getToken();
324 numRead = static_cast<size_t>(endTok.getLoc().getPointer() -
325 startTok.getLoc().getPointer());
326 return symbol;
327 }
328
parseAttribute(StringRef attrStr,MLIRContext * context)329 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) {
330 size_t numRead = 0;
331 return parseAttribute(attrStr, context, numRead);
332 }
parseAttribute(StringRef attrStr,Type type)333 Attribute mlir::parseAttribute(StringRef attrStr, Type type) {
334 size_t numRead = 0;
335 return parseAttribute(attrStr, type, numRead);
336 }
337
parseAttribute(StringRef attrStr,MLIRContext * context,size_t & numRead)338 Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
339 size_t &numRead) {
340 return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) {
341 return parser.parseAttribute();
342 });
343 }
parseAttribute(StringRef attrStr,Type type,size_t & numRead)344 Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) {
345 return parseSymbol<Attribute>(
346 attrStr, type.getContext(), numRead,
347 [type](Parser &parser) { return parser.parseAttribute(type); });
348 }
349
parseType(StringRef typeStr,MLIRContext * context)350 Type mlir::parseType(StringRef typeStr, MLIRContext *context) {
351 size_t numRead = 0;
352 return parseType(typeStr, context, numRead);
353 }
354
parseType(StringRef typeStr,MLIRContext * context,size_t & numRead)355 Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) {
356 return parseSymbol<Type>(typeStr, context, numRead,
357 [](Parser &parser) { return parser.parseType(); });
358 }
359