1 //===- FormatGen.cpp - Utilities for custom assembly formats ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "FormatGen.h"
10 #include "llvm/ADT/StringSwitch.h"
11 #include "llvm/Support/SourceMgr.h"
12 #include "llvm/TableGen/Error.h"
13 
14 using namespace mlir;
15 using namespace mlir::tblgen;
16 
17 //===----------------------------------------------------------------------===//
18 // FormatToken
19 //===----------------------------------------------------------------------===//
20 
getLoc() const21 SMLoc FormatToken::getLoc() const {
22   return SMLoc::getFromPointer(spelling.data());
23 }
24 
25 //===----------------------------------------------------------------------===//
26 // FormatLexer
27 //===----------------------------------------------------------------------===//
28 
FormatLexer(llvm::SourceMgr & mgr,SMLoc loc)29 FormatLexer::FormatLexer(llvm::SourceMgr &mgr, SMLoc loc)
30     : mgr(mgr), loc(loc),
31       curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
32       curPtr(curBuffer.begin()) {}
33 
emitError(SMLoc loc,const Twine & msg)34 FormatToken FormatLexer::emitError(SMLoc loc, const Twine &msg) {
35   mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
36   llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
37                             "in custom assembly format for this operation");
38   return formToken(FormatToken::error, loc.getPointer());
39 }
40 
emitError(const char * loc,const Twine & msg)41 FormatToken FormatLexer::emitError(const char *loc, const Twine &msg) {
42   return emitError(SMLoc::getFromPointer(loc), msg);
43 }
44 
emitErrorAndNote(SMLoc loc,const Twine & msg,const Twine & note)45 FormatToken FormatLexer::emitErrorAndNote(SMLoc loc, const Twine &msg,
46                                           const Twine &note) {
47   mgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
48   llvm::SrcMgr.PrintMessage(this->loc, llvm::SourceMgr::DK_Note,
49                             "in custom assembly format for this operation");
50   mgr.PrintMessage(loc, llvm::SourceMgr::DK_Note, note);
51   return formToken(FormatToken::error, loc.getPointer());
52 }
53 
getNextChar()54 int FormatLexer::getNextChar() {
55   char curChar = *curPtr++;
56   switch (curChar) {
57   default:
58     return (unsigned char)curChar;
59   case 0: {
60     // A nul character in the stream is either the end of the current buffer or
61     // a random nul in the file. Disambiguate that here.
62     if (curPtr - 1 != curBuffer.end())
63       return 0;
64 
65     // Otherwise, return end of file.
66     --curPtr;
67     return EOF;
68   }
69   case '\n':
70   case '\r':
71     // Handle the newline character by ignoring it and incrementing the line
72     // count. However, be careful about 'dos style' files with \n\r in them.
73     // Only treat a \n\r or \r\n as a single line.
74     if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
75       ++curPtr;
76     return '\n';
77   }
78 }
79 
lexToken()80 FormatToken FormatLexer::lexToken() {
81   const char *tokStart = curPtr;
82 
83   // This always consumes at least one character.
84   int curChar = getNextChar();
85   switch (curChar) {
86   default:
87     // Handle identifiers: [a-zA-Z_]
88     if (isalpha(curChar) || curChar == '_')
89       return lexIdentifier(tokStart);
90 
91     // Unknown character, emit an error.
92     return emitError(tokStart, "unexpected character");
93   case EOF:
94     // Return EOF denoting the end of lexing.
95     return formToken(FormatToken::eof, tokStart);
96 
97   // Lex punctuation.
98   case '^':
99     return formToken(FormatToken::caret, tokStart);
100   case ':':
101     return formToken(FormatToken::colon, tokStart);
102   case ',':
103     return formToken(FormatToken::comma, tokStart);
104   case '=':
105     return formToken(FormatToken::equal, tokStart);
106   case '<':
107     return formToken(FormatToken::less, tokStart);
108   case '>':
109     return formToken(FormatToken::greater, tokStart);
110   case '?':
111     return formToken(FormatToken::question, tokStart);
112   case '(':
113     return formToken(FormatToken::l_paren, tokStart);
114   case ')':
115     return formToken(FormatToken::r_paren, tokStart);
116   case '*':
117     return formToken(FormatToken::star, tokStart);
118   case '|':
119     return formToken(FormatToken::pipe, tokStart);
120 
121   // Ignore whitespace characters.
122   case 0:
123   case ' ':
124   case '\t':
125   case '\n':
126     return lexToken();
127 
128   case '`':
129     return lexLiteral(tokStart);
130   case '$':
131     return lexVariable(tokStart);
132   }
133 }
134 
lexLiteral(const char * tokStart)135 FormatToken FormatLexer::lexLiteral(const char *tokStart) {
136   assert(curPtr[-1] == '`');
137 
138   // Lex a literal surrounded by ``.
139   while (const char curChar = *curPtr++) {
140     if (curChar == '`')
141       return formToken(FormatToken::literal, tokStart);
142   }
143   return emitError(curPtr - 1, "unexpected end of file in literal");
144 }
145 
lexVariable(const char * tokStart)146 FormatToken FormatLexer::lexVariable(const char *tokStart) {
147   if (!isalpha(curPtr[0]) && curPtr[0] != '_')
148     return emitError(curPtr - 1, "expected variable name");
149 
150   // Otherwise, consume the rest of the characters.
151   while (isalnum(*curPtr) || *curPtr == '_')
152     ++curPtr;
153   return formToken(FormatToken::variable, tokStart);
154 }
155 
lexIdentifier(const char * tokStart)156 FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
157   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
158   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
159     ++curPtr;
160 
161   // Check to see if this identifier is a keyword.
162   StringRef str(tokStart, curPtr - tokStart);
163   auto kind =
164       StringSwitch<FormatToken::Kind>(str)
165           .Case("attr-dict", FormatToken::kw_attr_dict)
166           .Case("attr-dict-with-keyword", FormatToken::kw_attr_dict_w_keyword)
167           .Case("custom", FormatToken::kw_custom)
168           .Case("functional-type", FormatToken::kw_functional_type)
169           .Case("oilist", FormatToken::kw_oilist)
170           .Case("operands", FormatToken::kw_operands)
171           .Case("params", FormatToken::kw_params)
172           .Case("ref", FormatToken::kw_ref)
173           .Case("regions", FormatToken::kw_regions)
174           .Case("results", FormatToken::kw_results)
175           .Case("struct", FormatToken::kw_struct)
176           .Case("successors", FormatToken::kw_successors)
177           .Case("type", FormatToken::kw_type)
178           .Case("qualified", FormatToken::kw_qualified)
179           .Default(FormatToken::identifier);
180   return FormatToken(kind, str);
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // FormatParser
185 //===----------------------------------------------------------------------===//
186 
// Out-of-line defaulted destructor for FormatElement.
// NOTE(review): presumably defined here to give the class a home translation
// unit for its vtable -- confirm against the header.
FormatElement::~FormatElement() = default;

// Out-of-line defaulted destructor for FormatParser.
FormatParser::~FormatParser() = default;
190 
parse()191 FailureOr<std::vector<FormatElement *>> FormatParser::parse() {
192   SMLoc loc = curToken.getLoc();
193 
194   // Parse each of the format elements into the main format.
195   std::vector<FormatElement *> elements;
196   while (curToken.getKind() != FormatToken::eof) {
197     FailureOr<FormatElement *> element = parseElement(TopLevelContext);
198     if (failed(element))
199       return failure();
200     elements.push_back(*element);
201   }
202 
203   // Verify the format.
204   if (failed(verify(loc, elements)))
205     return failure();
206   return elements;
207 }
208 
//===----------------------------------------------------------------------===//
// Element Parsing
//===----------------------------------------------------------------------===//
211 
parseElement(Context ctx)212 FailureOr<FormatElement *> FormatParser::parseElement(Context ctx) {
213   if (curToken.is(FormatToken::literal))
214     return parseLiteral(ctx);
215   if (curToken.is(FormatToken::variable))
216     return parseVariable(ctx);
217   if (curToken.isKeyword())
218     return parseDirective(ctx);
219   if (curToken.is(FormatToken::l_paren))
220     return parseOptionalGroup(ctx);
221   return emitError(curToken.getLoc(),
222                    "expected literal, variable, directive, or optional group");
223 }
224 
parseLiteral(Context ctx)225 FailureOr<FormatElement *> FormatParser::parseLiteral(Context ctx) {
226   FormatToken tok = curToken;
227   SMLoc loc = tok.getLoc();
228   consumeToken();
229 
230   if (ctx != TopLevelContext) {
231     return emitError(
232         loc,
233         "literals may only be used in the top-level section of the format");
234   }
235   // Get the spelling without the surrounding backticks.
236   StringRef value = tok.getSpelling();
237   // Prevents things like `$arg0` or empty literals (when a literal is expected
238   // but not found) from getting segmentation faults.
239   if (value.size() < 2 || value[0] != '`' || value[value.size() - 1] != '`')
240     return emitError(tok.getLoc(), "expected literal, but got '" + value + "'");
241   value = value.drop_front().drop_back();
242 
243   // The parsed literal is a space element (`` or ` `) or a newline.
244   if (value.empty() || value == " " || value == "\\n")
245     return create<WhitespaceElement>(value);
246 
247   // Check that the parsed literal is valid.
248   if (!isValidLiteral(value, [&](Twine msg) {
249         (void)emitError(loc, "expected valid literal but got '" + value +
250                                  "': " + msg);
251       }))
252     return failure();
253   return create<LiteralElement>(value);
254 }
255 
parseVariable(Context ctx)256 FailureOr<FormatElement *> FormatParser::parseVariable(Context ctx) {
257   FormatToken tok = curToken;
258   SMLoc loc = tok.getLoc();
259   consumeToken();
260 
261   // Get the name of the variable without the leading `$`.
262   StringRef name = tok.getSpelling().drop_front();
263   return parseVariableImpl(loc, name, ctx);
264 }
265 
parseDirective(Context ctx)266 FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
267   FormatToken tok = curToken;
268   SMLoc loc = tok.getLoc();
269   consumeToken();
270 
271   if (tok.is(FormatToken::kw_custom))
272     return parseCustomDirective(loc, ctx);
273   return parseDirectiveImpl(loc, tok.getKind(), ctx);
274 }
275 
parseOptionalGroup(Context ctx)276 FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
277   SMLoc loc = curToken.getLoc();
278   consumeToken();
279   if (ctx != TopLevelContext) {
280     return emitError(loc,
281                      "optional groups can only be used as top-level elements");
282   }
283 
284   // Parse the child elements for this optional group.
285   std::vector<FormatElement *> thenElements, elseElements;
286   Optional<unsigned> anchorIndex;
287   do {
288     FailureOr<FormatElement *> element = parseElement(TopLevelContext);
289     if (failed(element))
290       return failure();
291     // Check for an anchor.
292     if (curToken.is(FormatToken::caret)) {
293       if (anchorIndex)
294         return emitError(curToken.getLoc(), "only one element can be marked as "
295                                             "the anchor of an optional group");
296       anchorIndex = thenElements.size();
297       consumeToken();
298     }
299     thenElements.push_back(*element);
300   } while (!curToken.is(FormatToken::r_paren));
301   consumeToken();
302 
303   // Parse the `else` elements of this optional group.
304   if (curToken.is(FormatToken::colon)) {
305     consumeToken();
306     if (failed(
307             parseToken(FormatToken::l_paren,
308                        "expected '(' to start else branch of optional group")))
309       return failure();
310     do {
311       FailureOr<FormatElement *> element = parseElement(TopLevelContext);
312       if (failed(element))
313         return failure();
314       elseElements.push_back(*element);
315     } while (!curToken.is(FormatToken::r_paren));
316     consumeToken();
317   }
318   if (failed(parseToken(FormatToken::question,
319                         "expected '?' after optional group")))
320     return failure();
321 
322   // The optional group is required to have an anchor.
323   if (!anchorIndex)
324     return emitError(loc, "optional group has no anchor element");
325 
326   // Verify the child elements.
327   if (failed(verifyOptionalGroupElements(loc, thenElements, anchorIndex)) ||
328       failed(verifyOptionalGroupElements(loc, elseElements, llvm::None)))
329     return failure();
330 
331   // Get the first parsable element. It must be an element that can be
332   // optionally-parsed.
333   auto parseBegin = llvm::find_if_not(thenElements, [](FormatElement *element) {
334     return isa<WhitespaceElement>(element);
335   });
336   if (!isa<LiteralElement, VariableElement>(*parseBegin)) {
337     return emitError(loc, "first parsable element of an optional group must be "
338                           "a literal or variable");
339   }
340 
341   unsigned parseStart = std::distance(thenElements.begin(), parseBegin);
342   return create<OptionalElement>(std::move(thenElements),
343                                  std::move(elseElements), *anchorIndex,
344                                  parseStart);
345 }
346 
parseCustomDirective(SMLoc loc,Context ctx)347 FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
348                                                               Context ctx) {
349   if (ctx != TopLevelContext)
350     return emitError(loc, "'custom' is only valid as a top-level directive");
351 
352   FailureOr<FormatToken> nameTok;
353   if (failed(parseToken(FormatToken::less,
354                         "expected '<' before custom directive name")) ||
355       failed(nameTok =
356                  parseToken(FormatToken::identifier,
357                             "expected custom directive name identifier")) ||
358       failed(parseToken(FormatToken::greater,
359                         "expected '>' after custom directive name")) ||
360       failed(parseToken(FormatToken::l_paren,
361                         "expected '(' before custom directive parameters")))
362     return failure();
363 
364   // Parse the arguments.
365   std::vector<FormatElement *> arguments;
366   while (true) {
367     FailureOr<FormatElement *> argument = parseElement(CustomDirectiveContext);
368     if (failed(argument))
369       return failure();
370     arguments.push_back(*argument);
371     if (!curToken.is(FormatToken::comma))
372       break;
373     consumeToken();
374   }
375 
376   if (failed(parseToken(FormatToken::r_paren,
377                         "expected ')' after custom directive parameters")))
378     return failure();
379 
380   if (failed(verifyCustomDirectiveArguments(loc, arguments)))
381     return failure();
382   return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // Utility Functions
387 //===----------------------------------------------------------------------===//
388 
shouldEmitSpaceBefore(StringRef value,bool lastWasPunctuation)389 bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
390                                          bool lastWasPunctuation) {
391   if (value.size() != 1 && value != "->")
392     return true;
393   if (lastWasPunctuation)
394     return !StringRef(">)}],").contains(value.front());
395   return !StringRef("<>(){}[],").contains(value.front());
396 }
397 
canFormatStringAsKeyword(StringRef value,function_ref<void (Twine)> emitError)398 bool mlir::tblgen::canFormatStringAsKeyword(
399     StringRef value, function_ref<void(Twine)> emitError) {
400   if (!isalpha(value.front()) && value.front() != '_') {
401     if (emitError)
402       emitError("valid keyword starts with a letter or '_'");
403     return false;
404   }
405   if (!llvm::all_of(value.drop_front(), [](char c) {
406         return isalnum(c) || c == '_' || c == '$' || c == '.';
407       })) {
408     if (emitError)
409       emitError(
410           "keywords should contain only alphanum, '_', '$', or '.' characters");
411     return false;
412   }
413   return true;
414 }
415 
isValidLiteral(StringRef value,function_ref<void (Twine)> emitError)416 bool mlir::tblgen::isValidLiteral(StringRef value,
417                                   function_ref<void(Twine)> emitError) {
418   if (value.empty()) {
419     if (emitError)
420       emitError("literal can't be empty");
421     return false;
422   }
423   char front = value.front();
424 
425   // If there is only one character, this must either be punctuation or a
426   // single character bare identifier.
427   if (value.size() == 1) {
428     StringRef bare = "_:,=<>()[]{}?+*";
429     if (isalpha(front) || bare.contains(front))
430       return true;
431     if (emitError)
432       emitError("single character literal must be a letter or one of '" + bare +
433                 "'");
434     return false;
435   }
436   // Check the punctuation that are larger than a single character.
437   if (value == "->")
438     return true;
439 
440   // Otherwise, this must be an identifier.
441   return canFormatStringAsKeyword(value, emitError);
442 }
443 
444 //===----------------------------------------------------------------------===//
445 // Commandline Options
446 //===----------------------------------------------------------------------===//
447 
// Boolean command-line option `asmformat-error-is-fatal`, enabled by default.
// NOTE(review): the code that reads this flag is not in this file; presumably
// the format generators consult it when parsing a format fails -- confirm.
llvm::cl::opt<bool> mlir::tblgen::formatErrorIsFatal(
    "asmformat-error-is-fatal",
    llvm::cl::desc("Emit a fatal error if format parsing fails"),
    llvm::cl::init(true));
452