1 //===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file uses tablegen definitions of the LLVM IR Dialect operations to
10 // generate the code building the LLVM IR from it.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Support/LogicalResult.h"
15 #include "mlir/TableGen/Attribute.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18 
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/Twine.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/TableGen/Record.h"
24 #include "llvm/TableGen/TableGenBackend.h"
25 
26 using namespace llvm;
27 using namespace mlir;
28 
29 static bool emitError(const Twine &message) {
30   llvm::errs() << message << "\n";
31   return false;
32 }
33 
34 namespace {
35 // Helper structure to return a position of the substring in a string.
36 struct StringLoc {
37   size_t pos;
38   size_t length;
39 
40   // Take a substring identified by this location in the given string.
41   StringRef in(StringRef str) const { return str.substr(pos, length); }
42 
43   // A location is invalid if its position is outside the string.
44   explicit operator bool() { return pos != std::string::npos; }
45 };
46 } // namespace
47 
48 // Find the next TableGen variable in the given pattern.  These variables start
49 // with a `$` character and can contain alphanumeric characters or underscores.
50 // Return the position of the variable in the pattern and its length, including
51 // the `$` character.  The escape syntax `$$` is also detected and returned.
52 static StringLoc findNextVariable(StringRef str) {
53   size_t startPos = str.find('$');
54   if (startPos == std::string::npos)
55     return {startPos, 0};
56 
57   // If we see "$$", return immediately.
58   if (startPos != str.size() - 1 && str[startPos + 1] == '$')
59     return {startPos, 2};
60 
61   // Otherwise, the symbol spans until the first character that is not
62   // alphanumeric or '_'.
63   size_t endPos = str.find_if_not([](char c) { return isAlnum(c) || c == '_'; },
64                                   startPos + 1);
65   if (endPos == std::string::npos)
66     endPos = str.size();
67 
68   return {startPos, endPos - startPos};
69 }
70 
71 // Check if `name` is the name of the variadic operand of `op`.  The variadic
72 // operand can only appear at the last position in the list of operands.
73 static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
74   unsigned numOperands = op.getNumOperands();
75   if (numOperands == 0)
76     return false;
77   const auto &operand = op.getOperand(numOperands - 1);
78   return operand.isVariableLength() && operand.name == name;
79 }
80 
81 // Check if `result` is a known name of a result of `op`.
82 static bool isResultName(const tblgen::Operator &op, StringRef name) {
83   for (int i = 0, e = op.getNumResults(); i < e; ++i)
84     if (op.getResultName(i) == name)
85       return true;
86   return false;
87 }
88 
89 // Check if `name` is a known name of an attribute of `op`.
90 static bool isAttributeName(const tblgen::Operator &op, StringRef name) {
91   return llvm::any_of(
92       op.getAttributes(),
93       [name](const tblgen::NamedAttribute &attr) { return attr.name == name; });
94 }
95 
96 // Check if `name` is a known name of an operand of `op`.
97 static bool isOperandName(const tblgen::Operator &op, StringRef name) {
98   for (int i = 0, e = op.getNumOperands(); i < e; ++i)
99     if (op.getOperand(i).name == name)
100       return true;
101   return false;
102 }
103 
104 // Emit to `os` the operator-name driven check and the call to LLVM IRBuilder
105 // for one definition of an LLVM IR Dialect operation.  Return true on success.
106 static bool emitOneBuilder(const Record &record, raw_ostream &os) {
107   auto op = tblgen::Operator(record);
108 
109   if (!record.getValue("llvmBuilder"))
110     return emitError("no 'llvmBuilder' field for op " + op.getOperationName());
111 
112   // Return early if there is no builder specified.
113   auto builderStrRef = record.getValueAsString("llvmBuilder");
114   if (builderStrRef.empty())
115     return true;
116 
117   // Progressively create the builder string by replacing $-variables with
118   // value lookups.  Keep only the not-yet-traversed part of the builder pattern
119   // to avoid re-traversing the string multiple times.
120   std::string builder;
121   llvm::raw_string_ostream bs(builder);
122   while (auto loc = findNextVariable(builderStrRef)) {
123     auto name = loc.in(builderStrRef).drop_front();
124     auto getterName = op.getGetterName(name);
125     // First, insert the non-matched part as is.
126     bs << builderStrRef.substr(0, loc.pos);
127     // Then, rewrite the name based on its kind.
128     bool isVariadicOperand = isVariadicOperandName(op, name);
129     if (isOperandName(op, name)) {
130       auto result =
131           isVariadicOperand
132               ? formatv("moduleTranslation.lookupValues(op.{0}())", getterName)
133               : formatv("moduleTranslation.lookupValue(op.{0}())", getterName);
134       bs << result;
135     } else if (isAttributeName(op, name)) {
136       bs << formatv("op.{0}()", getterName);
137     } else if (isResultName(op, name)) {
138       bs << formatv("moduleTranslation.mapValue(op.{0}())", getterName);
139     } else if (name == "_resultType") {
140       bs << "moduleTranslation.convertType(op.getResult().getType())";
141     } else if (name == "_hasResult") {
142       bs << "opInst.getNumResults() == 1";
143     } else if (name == "_location") {
144       bs << "opInst.getLoc()";
145     } else if (name == "_numOperands") {
146       bs << "opInst.getNumOperands()";
147     } else if (name == "$") {
148       bs << '$';
149     } else {
150       return emitError(name + " is neither an argument nor a result of " +
151                        op.getOperationName());
152     }
153     // Finally, only keep the untraversed part of the string.
154     builderStrRef = builderStrRef.substr(loc.pos + loc.length);
155   }
156 
157   // Output the check and the rewritten builder string.
158   os << "if (auto op = dyn_cast<" << op.getQualCppClassName()
159      << ">(opInst)) {\n";
160   os << bs.str() << builderStrRef << "\n";
161   os << "  return success();\n";
162   os << "}\n";
163 
164   return true;
165 }
166 
167 // Emit all builders.  Returns false on success because of the generator
168 // registration requirements.
169 static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
170   for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
171     if (!emitOneBuilder(*def, os))
172       return true;
173   }
174   return false;
175 }
176 
177 namespace {
178 // Wrapper class around a Tablegen definition of an LLVM enum attribute case.
179 class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
180 public:
181   using tblgen::EnumAttrCase::EnumAttrCase;
182 
183   // Constructs a case from a non LLVM-specific enum attribute case.
184   explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
185       : tblgen::EnumAttrCase(&other.getDef()) {}
186 
187   // Returns the C++ enumerant for the LLVM API.
188   StringRef getLLVMEnumerant() const {
189     return def->getValueAsString("llvmEnumerant");
190   }
191 };
192 
193 // Wraper class around a Tablegen definition of an LLVM enum attribute.
194 class LLVMEnumAttr : public tblgen::EnumAttr {
195 public:
196   using tblgen::EnumAttr::EnumAttr;
197 
198   // Returns the C++ enum name for the LLVM API.
199   StringRef getLLVMClassName() const {
200     return def->getValueAsString("llvmClassName");
201   }
202 
203   // Returns all associated cases viewed as LLVM-specific enum cases.
204   std::vector<LLVMEnumAttrCase> getAllCases() const {
205     std::vector<LLVMEnumAttrCase> cases;
206 
207     for (auto &c : tblgen::EnumAttr::getAllCases())
208       cases.emplace_back(c);
209 
210     return cases;
211   }
212 };
213 } // namespace
214 
215 // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
216 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
217 // (Enum) to the corresponding LLVM API enumerant
218 static void emitOneEnumToConversion(const llvm::Record *record,
219                                     raw_ostream &os) {
220   LLVMEnumAttr enumAttr(record);
221   StringRef llvmClass = enumAttr.getLLVMClassName();
222   StringRef cppClassName = enumAttr.getEnumClassName();
223   StringRef cppNamespace = enumAttr.getCppNamespace();
224 
225   // Emit the function converting the enum attribute to its LLVM counterpart.
226   os << formatv(
227       "static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n",
228       llvmClass, cppClassName, cppNamespace);
229   os << "  switch (value) {\n";
230 
231   for (const auto &enumerant : enumAttr.getAllCases()) {
232     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
233     StringRef cppEnumerant = enumerant.getSymbol();
234     os << formatv("  case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
235                   cppEnumerant);
236     os << formatv("    return {0}::{1};\n", llvmClass, llvmEnumerant);
237   }
238 
239   os << "  }\n";
240   os << formatv("  llvm_unreachable(\"unknown {0} type\");\n",
241                 enumAttr.getEnumClassName());
242   os << "}\n\n";
243 }
244 
245 // Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
246 // containing switch-based logic to convert from the LLVM API enumerant to MLIR
247 // LLVM dialect enum attribute (Enum).
248 static void emitOneEnumFromConversion(const llvm::Record *record,
249                                       raw_ostream &os) {
250   LLVMEnumAttr enumAttr(record);
251   StringRef llvmClass = enumAttr.getLLVMClassName();
252   StringRef cppClassName = enumAttr.getEnumClassName();
253   StringRef cppNamespace = enumAttr.getCppNamespace();
254 
255   // Emit the function converting the enum attribute from its LLVM counterpart.
256   os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} "
257                 "value) {{\n",
258                 cppNamespace, cppClassName, llvmClass);
259   os << "  switch (value) {\n";
260 
261   for (const auto &enumerant : enumAttr.getAllCases()) {
262     StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
263     StringRef cppEnumerant = enumerant.getSymbol();
264     os << formatv("  case {0}::{1}:\n", llvmClass, llvmEnumerant);
265     os << formatv("    return {0}::{1}::{2};\n", cppNamespace, cppClassName,
266                   cppEnumerant);
267   }
268 
269   os << "  }\n";
270   os << formatv("  llvm_unreachable(\"unknown {0} type\");",
271                 enumAttr.getLLVMClassName());
272   os << "}\n\n";
273 }
274 
275 // Emits conversion functions between MLIR enum attribute case and corresponding
276 // LLVM API enumerants for all registered LLVM dialect enum attributes.
277 template <bool ConvertTo>
278 static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
279                                    raw_ostream &os) {
280   for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr"))
281     if (ConvertTo)
282       emitOneEnumToConversion(def, os);
283     else
284       emitOneEnumFromConversion(def, os);
285 
286   return false;
287 }
288 
289 static mlir::GenRegistration
290     genLLVMIRConversions("gen-llvmir-conversions",
291                          "Generate LLVM IR conversions", emitBuilders);
292 
293 static mlir::GenRegistration
294     genEnumToLLVMConversion("gen-enum-to-llvmir-conversions",
295                             "Generate conversions of EnumAttrs to LLVM IR",
296                             emitEnumConversionDefs</*ConvertTo=*/true>);
297 
298 static mlir::GenRegistration
299     genEnumFromLLVMConversion("gen-enum-from-llvmir-conversions",
300                               "Generate conversions of EnumAttrs from LLVM IR",
301                               emitEnumConversionDefs</*ConvertTo=*/false>);
302