1 //===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "FormatGen.h"
10 #include "llvm/ADT/StringSwitch.h"
11 #include "llvm/Support/SourceMgr.h"
12 #include "llvm/TableGen/Error.h"
13
14 using namespace mlir;
15 using namespace mlir::tblgen;
16
17 //===----------------------------------------------------------------------===//
18 // FormatToken
19 //===----------------------------------------------------------------------===//
20
getLoc() const21 SMLoc FormatToken::getLoc() const {
22 return SMLoc::getFromPointer(spelling.data());
23 }
24
25 //===----------------------------------------------------------------------===//
26 // FormatLexer
27 //===----------------------------------------------------------------------===//
28
/// Create a lexer over the main file buffer registered in `mgr`. `loc` is
/// kept so that diagnostics can attach a note pointing back at it.
FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc)
    : mgr(mgr), loc(loc),
      // Lex the buffer of the manager's main file, starting at its first
      // character.
      // NOTE(review): `curPtr` is initialized from `curBuffer`, which relies on
      // `curBuffer` being declared before `curPtr` in the class — confirm in
      // FormatGen.h.
      curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
      curPtr(curBuffer.begin()) {}
33
emitError(SMLoc loc,const Twine & msg)34 FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
35 mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
36 llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
37 "in custom assembly format for this operation");
38 return formToken(FormatToken::error, loc.getPointer());
39 }
40
emitError(const char * loc,const Twine & msg)41 FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
42 return emitError(SMLoc::getFromPointer(loc), msg);
43 }
44
emitErrorAndNote(SMLoc loc,const Twine & msg,const Twine & note)45 FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
46 const Twine ¬e) {
47 mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
48 llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
49 "in custom assembly format for this operation");
50 mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
51 return formToken(FormatToken::error, loc.getPointer());
52 }
53
getNextChar()54 int FormatLexer::getNextChar() {
55 char curChar = *curPtr++;
56 switch (curChar) {
57 default:
58 return (unsigned char)curChar;
59 case 0: {
60 // A nul character in the stream is either the end of the current buffer or
61 // a random nul in the file. Disambiguate that here.
62 if (curPtr - 1 != curBuffer.end())
63 return 0;
64
65 // Otherwise, return end of file.
66 --curPtr;
67 return EOF;
68 }
69 case '\n':
70 case '\r':
71 // Handle the newline character by ignoring it and incrementing the line
72 // count. However, be careful about 'dos style' files with \n\r in them.
73 // Only treat a \n\r or \r\n as a single line.
74 if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
75 ++curPtr;
76 return '\n';
77 }
78 }
79
lexToken()80 FormatToken FormatLexer::lexToken() {
81 const char *tokStart = curPtr;
82
83 // This always consumes at least one character.
84 int curChar = getNextChar();
85 switch (curChar) {
86 default:
87 // Handle identifiers: [a-zA-Z_]
88 if (isalpha(curChar) || curChar == '_')
89 return lexIdentifier(tokStart);
90
91 // Unknown character, emit an error.
92 return emitError(tokStart, "unexpected character");
93 case EOF:
94 // Return EOF denoting the end of lexing.
95 return formToken(FormatToken::eof, tokStart);
96
97 // Lex punctuation.
98 case '^':
99 return formToken(FormatToken::caret, tokStart);
100 case ':':
101 return formToken(FormatToken::colon, tokStart);
102 case ',':
103 return formToken(FormatToken::comma, tokStart);
104 case '=':
105 return formToken(FormatToken::equal, tokStart);
106 case '<':
107 return formToken(FormatToken::less, tokStart);
108 case '>':
109 return formToken(FormatToken::greater, tokStart);
110 case '?':
111 return formToken(FormatToken::question, tokStart);
112 case '(':
113 return formToken(FormatToken::l_paren, tokStart);
114 case ')':
115 return formToken(FormatToken::r_paren, tokStart);
116 case '*':
117 return formToken(FormatToken::star, tokStart);
118 case '|':
119 return formToken(FormatToken::pipe, tokStart);
120
121 // Ignore whitespace characters.
122 case 0:
123 case ' ':
124 case '\t':
125 case '\n':
126 return lexToken();
127
128 case '`':
129 return lexLiteral(tokStart);
130 case '$':
131 return lexVariable(tokStart);
132 }
133 }
134
lexLiteral(const char * tokStart)135 FormatToken FormatLexer::lexLiteral(const char *tokStart) {
136 assert(curPtr[-1] == '`');
137
138 // Lex a literal surrounded by ``.
139 while (const char curChar = *curPtr++) {
140 if (curChar == '`')
141 return formToken(FormatToken::literal, tokStart);
142 }
143 return emitError(curPtr - 1, "unexpected end of file in literal");
144 }
145
lexVariable(const char * tokStart)146 FormatToken FormatLexer::lexVariable(const char *tokStart) {
147 if (!isalpha(curPtr[0]) && curPtr[0] != '_')
148 return emitError(curPtr - 1, "expected variable name");
149
150 // Otherwise, consume the rest of the characters.
151 while (isalnum(*curPtr) || *curPtr == '_')
152 ++curPtr;
153 return formToken(FormatToken::variable, tokStart);
154 }
155
lexIdentifier(const char * tokStart)156 FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
157 // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
158 while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
159 ++curPtr;
160
161 // Check to see if this identifier is a keyword.
162 StringRef str(tokStart, curPtr - tokStart);
163 auto kind =
164 StringSwitch<FormatToken::Kind>(str)
165 .Case("attr-dict", FormatToken::kw_attr_dict)
166 .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
167 .Case("custom", FormatToken::kw_custom)
168 .Case("functional-type", FormatToken::kw_functional_type)
169 .Case("oilist", FormatToken::kw_oilist)
170 .Case("operands", FormatToken::kw_operands)
171 .Case("params", FormatToken::kw_params)
172 .Case("ref", FormatToken::kw_ref)
173 .Case("regions", FormatToken::kw_regions)
174 .Case("results", FormatToken::kw_results)
175 .Case("struct", FormatToken::kw_struct)
176 .Case("successors", FormatToken::kw_successors)
177 .Case("type", FormatToken::kw_type)
178 .Case("qualified", FormatToken::kw_qualified)
179 .Default(FormatToken::identifier);
180 return FormatToken(kind, str);
181 }
182
183 //===----------------------------------------------------------------------===//
184 // FormatParser
185 //===----------------------------------------------------------------------===//
186
// Out-of-line defaulted destructors.
// NOTE(review): presumably defined here (rather than in the header) to anchor
// these classes' vtables in this translation unit — confirm in FormatGen.h.
FormatElement::~FormatElement() = default;

FormatParser::~FormatParser() = default;
190
parse()191 FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
192 SMLoc loc = curToken.getLoc();
193
194 // Parse each of the format elements into the main format.
195 std::vector<FormatElement *> elements;
196 while (curToken.getKind() != FormatToken::eof) {
197 FailureOr<FormatElement *> element = parseElement(TopLevelContext);
198 if (failed(element))
199 return failure();
200 elements.push_back(*element);
201 }
202
203 // Verify the format.
204 if (failed(verify(loc, elements)))
205 return failure();
206 return elements;
207 }
208
209 //===----------------------------------------------------------------------===//
210 // Element Parsing
211
parseElement(Context ctx)212 FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
213 if (curToken.is(FormatToken::literal))
214 return parseLiteral(ctx);
215 if (curToken.is(FormatToken::variable))
216 return parseVariable(ctx);
217 if (curToken.isKeyword())
218 return parseDirective(ctx);
219 if (curToken.is(FormatToken::l_paren))
220 return parseOptionalGroup(ctx);
221 return emitError(curToken.getLoc(),
222 "expected literal, variable, directive, or optional group");
223 }
224
parseLiteral(Context ctx)225 FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
226 FormatToken tok = curToken;
227 SMLoc loc = tok.getLoc();
228 consumeToken();
229
230 if (ctx != TopLevelContext) {
231 return emitError(
232 loc,
233 "literals may only be used in the top-level section of the format");
234 }
235 // Get the spelling without the surrounding backticks.
236 StringRef value = tok.getSpelling();
237 // Prevents things like `$arg0` or empty literals (when a literal is expected
238 // but not found) from getting segmentation faults.
239 if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`')
240 return emitError(tok.getLoc(), "expected literal, but got '" + value + "'");
241 value = value.drop_front().drop_back();
242
243 // The parsed literal is a space element (`` or ` `) or a newline.
244 if (value.empty() || value == " " || value == "\\n")
245 return create<WhitespaceElement>(value);
246
247 // Check that the parsed literal is valid.
248 if (!isValidLiteral(value, [&](Twine msg) {
249 (void)emitError(loc, "expected valid literal but got '" + value +
250 "': " + msg);
251 }))
252 return failure();
253 return create<LiteralElement>(value);
254 }
255
parseVariable(Context ctx)256 FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
257 FormatToken tok = curToken;
258 SMLoc loc = tok.getLoc();
259 consumeToken();
260
261 // Get the name of the variable without the leading `$`.
262 StringRef name = tok.getSpelling().drop_front();
263 return parseVariableImpl(loc, name, ctx);
264 }
265
parseDirective(Context ctx)266 FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
267 FormatToken tok = curToken;
268 SMLoc loc = tok.getLoc();
269 consumeToken();
270
271 if (tok.is(FormatToken::kw_custom))
272 return parseCustomDirective(loc, ctx);
273 return parseDirectiveImpl(loc, tok.getKind(), ctx);
274 }
275
parseOptionalGroup(Context ctx)276 FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
277 SMLoc loc = curToken.getLoc();
278 consumeToken();
279 if (ctx != TopLevelContext) {
280 return emitError(loc,
281 "optional groups can only be used as top-level elements");
282 }
283
284 // Parse the child elements for this optional group.
285 std::vector<FormatElement *> thenElements, elseElements;
286 Optional<unsigned> anchorIndex;
287 do {
288 FailureOr<FormatElement *> element = parseElement(TopLevelContext);
289 if (failed(element))
290 return failure();
291 // Check for an anchor.
292 if (curToken.is(FormatToken::caret)) {
293 if (anchorIndex)
294 return emitError(curToken.getLoc(), "only one element can be marked as "
295 "the anchor of an optional group");
296 anchorIndex = thenElements.size();
297 consumeToken();
298 }
299 thenElements.push_back(*element);
300 } while (!curToken.is(FormatToken::r_paren));
301 consumeToken();
302
303 // Parse the `else` elements of this optional group.
304 if (curToken.is(FormatToken::colon)) {
305 consumeToken();
306 if (failed(
307 parseToken(FormatToken::l_paren,
308 "expected '(' to start else branch of optional group")))
309 return failure();
310 do {
311 FailureOr<FormatElement *> element = parseElement(TopLevelContext);
312 if (failed(element))
313 return failure();
314 elseElements.push_back(*element);
315 } while (!curToken.is(FormatToken::r_paren));
316 consumeToken();
317 }
318 if (failed(parseToken(FormatToken::question,
319 "expected '?' after optional group")))
320 return failure();
321
322 // The optional group is required to have an anchor.
323 if (!anchorIndex)
324 return emitError(loc, "optional group has no anchor element");
325
326 // Verify the child elements.
327 if (failed(verifyOptionalGroupElements(loc, thenElements, anchorIndex)) ||
328 failed(verifyOptionalGroupElements(loc, elseElements, llvm::None)))
329 return failure();
330
331 // Get the first parsable element. It must be an element that can be
332 // optionally-parsed.
333 auto parseBegin = llvm::find_if_not(thenElements, [](FormatElement *element) {
334 return isa<WhitespaceElement>(element);
335 });
336 if (!isa<LiteralElement, VariableElement>(*parseBegin)) {
337 return emitError(loc, "first parsable element of an optional group must be "
338 "a literal or variable");
339 }
340
341 unsigned parseStart = std::distance(thenElements.begin(), parseBegin);
342 return create<OptionalElement>(std::move(thenElements),
343 std::move(elseElements), *anchorIndex,
344 parseStart);
345 }
346
parseCustomDirective(SMLoc loc,Context ctx)347 FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
348 Context ctx) {
349 if (ctx != TopLevelContext)
350 return emitError(loc, "'custom' is only valid as a top-level directive");
351
352 FailureOr<FormatToken> nameTok;
353 if (failed(parseToken(FormatToken::less,
354 "expected '<' before custom directive name")) ||
355 failed(nameTok =
356 parseToken(FormatToken::identifier,
357 "expected custom directive name identifier")) ||
358 failed(parseToken(FormatToken::greater,
359 "expected '>' after custom directive name")) ||
360 failed(parseToken(FormatToken::l_paren,
361 "expected '(' before custom directive parameters")))
362 return failure();
363
364 // Parse the arguments.
365 std::vector<FormatElement *> arguments;
366 while (true) {
367 FailureOr<FormatElement *> argument = parseElement(CustomDirectiveContext);
368 if (failed(argument))
369 return failure();
370 arguments.push_back(*argument);
371 if (!curToken.is(FormatToken::comma))
372 break;
373 consumeToken();
374 }
375
376 if (failed(parseToken(FormatToken::r_paren,
377 "expected ')' after custom directive parameters")))
378 return failure();
379
380 if (failed(verifyCustomDirectiveArguments(loc, arguments)))
381 return failure();
382 return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
383 }
384
385 //===----------------------------------------------------------------------===//
386 // Utility Functions
387 //===----------------------------------------------------------------------===//
388
shouldEmitSpaceBefore(StringRef value,bool lastWasPunctuation)389 bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
390 bool lastWasPunctuation) {
391 if (value.size() != 1 && value != "->")
392 return true;
393 if (lastWasPunctuation)
394 return !StringRef(">)}],").contains(value.front());
395 return !StringRef("<>(){}[],").contains(value.front());
396 }
397
canFormatStringAsKeyword(StringRef value,function_ref<void (Twine)> emitError)398 bool mlir::tblgen::canFormatStringAsKeyword(
399 StringRef value, function_ref<void(Twine)> emitError) {
400 if (!isalpha(value.front()) && value.front() != '_') {
401 if (emitError)
402 emitError("valid keyword starts with a letter or '_'");
403 return false;
404 }
405 if (!llvm::all_of(value.drop_front(), [](char c) {
406 return isalnum(c) || c == '_' || c == '$' || c == '.';
407 })) {
408 if (emitError)
409 emitError(
410 "keywords should contain only alphanum, '_', '$', or '.' characters");
411 return false;
412 }
413 return true;
414 }
415
isValidLiteral(StringRef value,function_ref<void (Twine)> emitError)416 bool mlir::tblgen::isValidLiteral(StringRef value,
417 function_ref<void(Twine)> emitError) {
418 if (value.empty()) {
419 if (emitError)
420 emitError("literal can't be empty");
421 return false;
422 }
423 char front = value.front();
424
425 // If there is only one character, this must either be punctuation or a
426 // single character bare identifier.
427 if (value.size() == 1) {
428 StringRef bare = "_:,=<>()[]{}?+*";
429 if (isalpha(front) || bare.contains(front))
430 return true;
431 if (emitError)
432 emitError("single character literal must be a letter or one of '" + bare +
433 "'");
434 return false;
435 }
436 // Check the punctuation that are larger than a single character.
437 if (value == "->")
438 return true;
439
440 // Otherwise, this must be an identifier.
441 return canFormatStringAsKeyword(value, emitError);
442 }
443
444 //===----------------------------------------------------------------------===//
445 // Commandline Options
446 //===----------------------------------------------------------------------===//
447
// Command-line flag `-asmformat-error-is-fatal`, default `true`. Its
// description states that, when set, a failure to parse a format is emitted as
// a fatal error. NOTE(review): the code that reads this flag is not in this
// file — confirm the exact behavior at its use sites.
llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
    "asmformat-error-is-fatal",
    llvm::cl::desc("Emit a fatal error if format parsing fails"),
    llvm::cl::init(true));
452