1 //===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file uses tablegen definitions of the LLVM IR Dialect operations to
10 // generate the code building the LLVM IR from it.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Support/LogicalResult.h"
15 #include "mlir/TableGen/Attribute.h"
16 #include "mlir/TableGen/GenInfo.h"
17 #include "mlir/TableGen/Operator.h"
18
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/Twine.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "llvm/TableGen/Record.h"
24 #include "llvm/TableGen/TableGenBackend.h"
25
26 using namespace llvm;
27 using namespace mlir;
28
emitError(const Twine & message)29 static bool emitError(const Twine &message) {
30 llvm::errs() << message << "\n";
31 return false;
32 }
33
34 namespace {
35 // Helper structure to return a position of the substring in a string.
36 struct StringLoc {
37 size_t pos;
38 size_t length;
39
40 // Take a substring identified by this location in the given string.
in__anon55b832950111::StringLoc41 StringRef in(StringRef str) const { return str.substr(pos, length); }
42
43 // A location is invalid if its position is outside the string.
operator bool__anon55b832950111::StringLoc44 explicit operator bool() { return pos != std::string::npos; }
45 };
46 } // namespace
47
48 // Find the next TableGen variable in the given pattern. These variables start
49 // with a `$` character and can contain alphanumeric characters or underscores.
50 // Return the position of the variable in the pattern and its length, including
51 // the `$` character. The escape syntax `$$` is also detected and returned.
findNextVariable(StringRef str)52 static StringLoc findNextVariable(StringRef str) {
53 size_t startPos = str.find('$');
54 if (startPos == std::string::npos)
55 return {startPos, 0};
56
57 // If we see "$$", return immediately.
58 if (startPos != str.size() - 1 && str[startPos + 1] == '$')
59 return {startPos, 2};
60
61 // Otherwise, the symbol spans until the first character that is not
62 // alphanumeric or '_'.
63 size_t endPos = str.find_if_not([](char c) { return isAlnum(c) || c == '_'; },
64 startPos + 1);
65 if (endPos == std::string::npos)
66 endPos = str.size();
67
68 return {startPos, endPos - startPos};
69 }
70
71 // Check if `name` is the name of the variadic operand of `op`. The variadic
72 // operand can only appear at the last position in the list of operands.
isVariadicOperandName(const tblgen::Operator & op,StringRef name)73 static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
74 unsigned numOperands = op.getNumOperands();
75 if (numOperands == 0)
76 return false;
77 const auto &operand = op.getOperand(numOperands - 1);
78 return operand.isVariableLength() && operand.name == name;
79 }
80
81 // Check if `result` is a known name of a result of `op`.
isResultName(const tblgen::Operator & op,StringRef name)82 static bool isResultName(const tblgen::Operator &op, StringRef name) {
83 for (int i = 0, e = op.getNumResults(); i < e; ++i)
84 if (op.getResultName(i) == name)
85 return true;
86 return false;
87 }
88
89 // Check if `name` is a known name of an attribute of `op`.
isAttributeName(const tblgen::Operator & op,StringRef name)90 static bool isAttributeName(const tblgen::Operator &op, StringRef name) {
91 return llvm::any_of(
92 op.getAttributes(),
93 [name](const tblgen::NamedAttribute &attr) { return attr.name == name; });
94 }
95
96 // Check if `name` is a known name of an operand of `op`.
isOperandName(const tblgen::Operator & op,StringRef name)97 static bool isOperandName(const tblgen::Operator &op, StringRef name) {
98 for (int i = 0, e = op.getNumOperands(); i < e; ++i)
99 if (op.getOperand(i).name == name)
100 return true;
101 return false;
102 }
103
104 // Emit to `os` the operator-name driven check and the call to LLVM IRBuilder
105 // for one definition of an LLVM IR Dialect operation. Return true on success.
emitOneBuilder(const Record & record,raw_ostream & os)106 static bool emitOneBuilder(const Record &record, raw_ostream &os) {
107 auto op = tblgen::Operator(record);
108
109 if (!record.getValue("llvmBuilder"))
110 return emitError("no 'llvmBuilder' field for op " + op.getOperationName());
111
112 // Return early if there is no builder specified.
113 auto builderStrRef = record.getValueAsString("llvmBuilder");
114 if (builderStrRef.empty())
115 return true;
116
117 // Progressively create the builder string by replacing $-variables with
118 // value lookups. Keep only the not-yet-traversed part of the builder pattern
119 // to avoid re-traversing the string multiple times.
120 std::string builder;
121 llvm::raw_string_ostream bs(builder);
122 while (auto loc = findNextVariable(builderStrRef)) {
123 auto name = loc.in(builderStrRef).drop_front();
124 auto getterName = op.getGetterName(name);
125 // First, insert the non-matched part as is.
126 bs << builderStrRef.substr(0, loc.pos);
127 // Then, rewrite the name based on its kind.
128 bool isVariadicOperand = isVariadicOperandName(op, name);
129 if (isOperandName(op, name)) {
130 auto result =
131 isVariadicOperand
132 ? formatv("moduleTranslation.lookupValues(op.{0}())", getterName)
133 : formatv("moduleTranslation.lookupValue(op.{0}())", getterName);
134 bs << result;
135 } else if (isAttributeName(op, name)) {
136 bs << formatv("op.{0}()", getterName);
137 } else if (isResultName(op, name)) {
138 bs << formatv("moduleTranslation.mapValue(op.{0}())", getterName);
139 } else if (name == "_resultType") {
140 bs << "moduleTranslation.convertType(op.getResult().getType())";
141 } else if (name == "_hasResult") {
142 bs << "opInst.getNumResults() == 1";
143 } else if (name == "_location") {
144 bs << "opInst.getLoc()";
145 } else if (name == "_numOperands") {
146 bs << "opInst.getNumOperands()";
147 } else if (name == "$") {
148 bs << '$';
149 } else {
150 return emitError(name + " is neither an argument nor a result of " +
151 op.getOperationName());
152 }
153 // Finally, only keep the untraversed part of the string.
154 builderStrRef = builderStrRef.substr(loc.pos + loc.length);
155 }
156
157 // Output the check and the rewritten builder string.
158 os << "if (auto op = dyn_cast<" << op.getQualCppClassName()
159 << ">(opInst)) {\n";
160 os << bs.str() << builderStrRef << "\n";
161 os << " return success();\n";
162 os << "}\n";
163
164 return true;
165 }
166
167 // Emit all builders. Returns false on success because of the generator
168 // registration requirements.
emitBuilders(const RecordKeeper & recordKeeper,raw_ostream & os)169 static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
170 for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
171 if (!emitOneBuilder(*def, os))
172 return true;
173 }
174 return false;
175 }
176
177 namespace {
178 // Wrapper class around a Tablegen definition of an LLVM enum attribute case.
179 class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
180 public:
181 using tblgen::EnumAttrCase::EnumAttrCase;
182
183 // Constructs a case from a non LLVM-specific enum attribute case.
LLVMEnumAttrCase(const tblgen::EnumAttrCase & other)184 explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
185 : tblgen::EnumAttrCase(&other.getDef()) {}
186
187 // Returns the C++ enumerant for the LLVM API.
getLLVMEnumerant() const188 StringRef getLLVMEnumerant() const {
189 return def->getValueAsString("llvmEnumerant");
190 }
191 };
192
193 // Wraper class around a Tablegen definition of an LLVM enum attribute.
194 class LLVMEnumAttr : public tblgen::EnumAttr {
195 public:
196 using tblgen::EnumAttr::EnumAttr;
197
198 // Returns the C++ enum name for the LLVM API.
getLLVMClassName() const199 StringRef getLLVMClassName() const {
200 return def->getValueAsString("llvmClassName");
201 }
202
203 // Returns all associated cases viewed as LLVM-specific enum cases.
getAllCases() const204 std::vector<LLVMEnumAttrCase> getAllCases() const {
205 std::vector<LLVMEnumAttrCase> cases;
206
207 for (auto &c : tblgen::EnumAttr::getAllCases())
208 cases.emplace_back(c);
209
210 return cases;
211 }
212 };
213
214 // Wraper class around a Tablegen definition of a C-style LLVM enum attribute.
215 class LLVMCEnumAttr : public tblgen::EnumAttr {
216 public:
217 using tblgen::EnumAttr::EnumAttr;
218
219 // Returns the C++ enum name for the LLVM API.
getLLVMClassName() const220 StringRef getLLVMClassName() const {
221 return def->getValueAsString("llvmClassName");
222 }
223
224 // Returns all associated cases viewed as LLVM-specific enum cases.
getAllCases() const225 std::vector<LLVMEnumAttrCase> getAllCases() const {
226 std::vector<LLVMEnumAttrCase> cases;
227
228 for (auto &c : tblgen::EnumAttr::getAllCases())
229 cases.emplace_back(c);
230
231 return cases;
232 }
233 };
234 } // namespace
235
236 // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
237 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
238 // (Enum) to the corresponding LLVM API enumerant
emitOneEnumToConversion(const llvm::Record * record,raw_ostream & os)239 static void emitOneEnumToConversion(const llvm::Record *record,
240 raw_ostream &os) {
241 LLVMEnumAttr enumAttr(record);
242 StringRef llvmClass = enumAttr.getLLVMClassName();
243 StringRef cppClassName = enumAttr.getEnumClassName();
244 StringRef cppNamespace = enumAttr.getCppNamespace();
245
246 // Emit the function converting the enum attribute to its LLVM counterpart.
247 os << formatv(
248 "static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n",
249 llvmClass, cppClassName, cppNamespace);
250 os << " switch (value) {\n";
251
252 for (const auto &enumerant : enumAttr.getAllCases()) {
253 StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
254 StringRef cppEnumerant = enumerant.getSymbol();
255 os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
256 cppEnumerant);
257 os << formatv(" return {0}::{1};\n", llvmClass, llvmEnumerant);
258 }
259
260 os << " }\n";
261 os << formatv(" llvm_unreachable(\"unknown {0} type\");\n",
262 enumAttr.getEnumClassName());
263 os << "}\n\n";
264 }
265
266 // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
267 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case
268 // (Enum) to the corresponding LLVM API C-style enumerant
emitOneCEnumToConversion(const llvm::Record * record,raw_ostream & os)269 static void emitOneCEnumToConversion(const llvm::Record *record,
270 raw_ostream &os) {
271 LLVMCEnumAttr enumAttr(record);
272 StringRef llvmClass = enumAttr.getLLVMClassName();
273 StringRef cppClassName = enumAttr.getEnumClassName();
274 StringRef cppNamespace = enumAttr.getCppNamespace();
275
276 // Emit the function converting the enum attribute to its LLVM counterpart.
277 os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t "
278 "convert{0}ToLLVM({1}::{0} value) {{\n",
279 cppClassName, cppNamespace);
280 os << " switch (value) {\n";
281
282 for (const auto &enumerant : enumAttr.getAllCases()) {
283 StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
284 StringRef cppEnumerant = enumerant.getSymbol();
285 os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
286 cppEnumerant);
287 os << formatv(" return static_cast<int64_t>({0}::{1});\n", llvmClass,
288 llvmEnumerant);
289 }
290
291 os << " }\n";
292 os << formatv(" llvm_unreachable(\"unknown {0} type\");\n",
293 enumAttr.getEnumClassName());
294 os << "}\n\n";
295 }
296
297 // Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
298 // containing switch-based logic to convert from the LLVM API enumerant to MLIR
299 // LLVM dialect enum attribute (Enum).
emitOneEnumFromConversion(const llvm::Record * record,raw_ostream & os)300 static void emitOneEnumFromConversion(const llvm::Record *record,
301 raw_ostream &os) {
302 LLVMEnumAttr enumAttr(record);
303 StringRef llvmClass = enumAttr.getLLVMClassName();
304 StringRef cppClassName = enumAttr.getEnumClassName();
305 StringRef cppNamespace = enumAttr.getCppNamespace();
306
307 // Emit the function converting the enum attribute from its LLVM counterpart.
308 os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} "
309 "value) {{\n",
310 cppNamespace, cppClassName, llvmClass);
311 os << " switch (value) {\n";
312
313 for (const auto &enumerant : enumAttr.getAllCases()) {
314 StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
315 StringRef cppEnumerant = enumerant.getSymbol();
316 os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant);
317 os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
318 cppEnumerant);
319 }
320
321 os << " }\n";
322 os << formatv(" llvm_unreachable(\"unknown {0} type\");",
323 enumAttr.getLLVMClassName());
324 os << "}\n\n";
325 }
326
327 // Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and
328 // containing switch-based logic to convert from the LLVM API C-style enumerant
329 // to MLIR LLVM dialect enum attribute (Enum).
emitOneCEnumFromConversion(const llvm::Record * record,raw_ostream & os)330 static void emitOneCEnumFromConversion(const llvm::Record *record,
331 raw_ostream &os) {
332 LLVMCEnumAttr enumAttr(record);
333 StringRef llvmClass = enumAttr.getLLVMClassName();
334 StringRef cppClassName = enumAttr.getEnumClassName();
335 StringRef cppNamespace = enumAttr.getCppNamespace();
336
337 // Emit the function converting the enum attribute from its LLVM counterpart.
338 os << formatv(
339 "inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t "
340 "value) {{\n",
341 cppNamespace, cppClassName, llvmClass);
342 os << " switch (value) {\n";
343
344 for (const auto &enumerant : enumAttr.getAllCases()) {
345 StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
346 StringRef cppEnumerant = enumerant.getSymbol();
347 os << formatv(" case static_cast<int64_t>({0}::{1}):\n", llvmClass,
348 llvmEnumerant);
349 os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
350 cppEnumerant);
351 }
352
353 os << " }\n";
354 os << formatv(" llvm_unreachable(\"unknown {0} type\");",
355 enumAttr.getLLVMClassName());
356 os << "}\n\n";
357 }
358
359 // Emits conversion functions between MLIR enum attribute case and corresponding
360 // LLVM API enumerants for all registered LLVM dialect enum attributes.
361 template <bool ConvertTo>
emitEnumConversionDefs(const RecordKeeper & recordKeeper,raw_ostream & os)362 static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
363 raw_ostream &os) {
364 for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr"))
365 if (ConvertTo)
366 emitOneEnumToConversion(def, os);
367 else
368 emitOneEnumFromConversion(def, os);
369
370 for (const auto *def :
371 recordKeeper.getAllDerivedDefinitions("LLVM_CEnumAttr"))
372 if (ConvertTo)
373 emitOneCEnumToConversion(def, os);
374 else
375 emitOneCEnumFromConversion(def, os);
376
377 return false;
378 }
379
emitIntrOpPair(const Record & record,raw_ostream & os)380 static void emitIntrOpPair(const Record &record, raw_ostream &os) {
381 auto op = tblgen::Operator(record);
382 os << "{llvm::Intrinsic::" << record.getValueAsString("llvmEnumName") << ", "
383 << op.getQualCppClassName() << "::getOperationName()},\n";
384 }
385
emitIntrOpPairs(const RecordKeeper & recordKeeper,raw_ostream & os)386 static bool emitIntrOpPairs(const RecordKeeper &recordKeeper, raw_ostream &os) {
387 for (const auto *def :
388 recordKeeper.getAllDerivedDefinitions("LLVM_IntrOpBase"))
389 emitIntrOpPair(*def, os);
390
391 return false;
392 }
393
394 static mlir::GenRegistration
395 genLLVMIRConversions("gen-llvmir-conversions",
396 "Generate LLVM IR conversions", emitBuilders);
397
398 static mlir::GenRegistration
399 genEnumToLLVMConversion("gen-enum-to-llvmir-conversions",
400 "Generate conversions of EnumAttrs to LLVM IR",
401 emitEnumConversionDefs</*ConvertTo=*/true>);
402
403 static mlir::GenRegistration
404 genEnumFromLLVMConversion("gen-enum-from-llvmir-conversions",
405 "Generate conversions of EnumAttrs from LLVM IR",
406 emitEnumConversionDefs</*ConvertTo=*/false>);
407
408 static mlir::GenRegistration
409 genLLVMIntrinsicToOpPairs("gen-llvmintrinsic-to-llvmirop-pairs",
410 "Generate LLVM intrinsic to LLVMIR op pairs",
411 emitIntrOpPairs);
412