1 //===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file uses tablegen definitions of the LLVM IR Dialect operations to 10 // generate the code building the LLVM IR from it. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Support/LogicalResult.h" 15 #include "mlir/TableGen/Attribute.h" 16 #include "mlir/TableGen/GenInfo.h" 17 #include "mlir/TableGen/Operator.h" 18 19 #include "llvm/ADT/StringExtras.h" 20 #include "llvm/ADT/Twine.h" 21 #include "llvm/Support/FormatVariadic.h" 22 #include "llvm/Support/raw_ostream.h" 23 #include "llvm/TableGen/Record.h" 24 #include "llvm/TableGen/TableGenBackend.h" 25 26 using namespace llvm; 27 using namespace mlir; 28 29 static bool emitError(const Twine &message) { 30 llvm::errs() << message << "\n"; 31 return false; 32 } 33 34 namespace { 35 // Helper structure to return a position of the substring in a string. 36 struct StringLoc { 37 size_t pos; 38 size_t length; 39 40 // Take a substring identified by this location in the given string. 41 StringRef in(StringRef str) const { return str.substr(pos, length); } 42 43 // A location is invalid if its position is outside the string. 44 explicit operator bool() { return pos != std::string::npos; } 45 }; 46 } // namespace 47 48 // Find the next TableGen variable in the given pattern. These variables start 49 // with a `$` character and can contain alphanumeric characters or underscores. 50 // Return the position of the variable in the pattern and its length, including 51 // the `$` character. The escape syntax `$$` is also detected and returned. 52 static StringLoc findNextVariable(StringRef str) { 53 size_t startPos = str.find('$'); 54 if (startPos == std::string::npos) 55 return {startPos, 0}; 56 57 // If we see "$$", return immediately. 58 if (startPos != str.size() - 1 && str[startPos + 1] == '$') 59 return {startPos, 2}; 60 61 // Otherwise, the symbol spans until the first character that is not 62 // alphanumeric or '_'. 63 size_t endPos = str.find_if_not([](char c) { return isAlnum(c) || c == '_'; }, 64 startPos + 1); 65 if (endPos == std::string::npos) 66 endPos = str.size(); 67 68 return {startPos, endPos - startPos}; 69 } 70 71 // Check if `name` is the name of the variadic operand of `op`. The variadic 72 // operand can only appear at the last position in the list of operands. 73 static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) { 74 unsigned numOperands = op.getNumOperands(); 75 if (numOperands == 0) 76 return false; 77 const auto &operand = op.getOperand(numOperands - 1); 78 return operand.isVariableLength() && operand.name == name; 79 } 80 81 // Check if `result` is a known name of a result of `op`. 82 static bool isResultName(const tblgen::Operator &op, StringRef name) { 83 for (int i = 0, e = op.getNumResults(); i < e; ++i) 84 if (op.getResultName(i) == name) 85 return true; 86 return false; 87 } 88 89 // Check if `name` is a known name of an attribute of `op`. 90 static bool isAttributeName(const tblgen::Operator &op, StringRef name) { 91 return llvm::any_of( 92 op.getAttributes(), 93 [name](const tblgen::NamedAttribute &attr) { return attr.name == name; }); 94 } 95 96 // Check if `name` is a known name of an operand of `op`. 97 static bool isOperandName(const tblgen::Operator &op, StringRef name) { 98 for (int i = 0, e = op.getNumOperands(); i < e; ++i) 99 if (op.getOperand(i).name == name) 100 return true; 101 return false; 102 } 103 104 // Emit to `os` the operator-name driven check and the call to LLVM IRBuilder 105 // for one definition of an LLVM IR Dialect operation. Return true on success. 106 static bool emitOneBuilder(const Record &record, raw_ostream &os) { 107 auto op = tblgen::Operator(record); 108 109 if (!record.getValue("llvmBuilder")) 110 return emitError("no 'llvmBuilder' field for op " + op.getOperationName()); 111 112 // Return early if there is no builder specified. 113 auto builderStrRef = record.getValueAsString("llvmBuilder"); 114 if (builderStrRef.empty()) 115 return true; 116 117 // Progressively create the builder string by replacing $-variables with 118 // value lookups. Keep only the not-yet-traversed part of the builder pattern 119 // to avoid re-traversing the string multiple times. 120 std::string builder; 121 llvm::raw_string_ostream bs(builder); 122 while (auto loc = findNextVariable(builderStrRef)) { 123 auto name = loc.in(builderStrRef).drop_front(); 124 auto getterName = op.getGetterName(name); 125 // First, insert the non-matched part as is. 126 bs << builderStrRef.substr(0, loc.pos); 127 // Then, rewrite the name based on its kind. 128 bool isVariadicOperand = isVariadicOperandName(op, name); 129 if (isOperandName(op, name)) { 130 auto result = 131 isVariadicOperand 132 ? formatv("moduleTranslation.lookupValues(op.{0}())", getterName) 133 : formatv("moduleTranslation.lookupValue(op.{0}())", getterName); 134 bs << result; 135 } else if (isAttributeName(op, name)) { 136 bs << formatv("op.{0}()", getterName); 137 } else if (isResultName(op, name)) { 138 bs << formatv("moduleTranslation.mapValue(op.{0}())", getterName); 139 } else if (name == "_resultType") { 140 bs << "moduleTranslation.convertType(op.getResult().getType())"; 141 } else if (name == "_hasResult") { 142 bs << "opInst.getNumResults() == 1"; 143 } else if (name == "_location") { 144 bs << "opInst.getLoc()"; 145 } else if (name == "_numOperands") { 146 bs << "opInst.getNumOperands()"; 147 } else if (name == "$") { 148 bs << '$'; 149 } else { 150 return emitError(name + " is neither an argument nor a result of " + 151 op.getOperationName()); 152 } 153 // Finally, only keep the untraversed part of the string. 154 builderStrRef = builderStrRef.substr(loc.pos + loc.length); 155 } 156 157 // Output the check and the rewritten builder string. 158 os << "if (auto op = dyn_cast<" << op.getQualCppClassName() 159 << ">(opInst)) {\n"; 160 os << bs.str() << builderStrRef << "\n"; 161 os << " return success();\n"; 162 os << "}\n"; 163 164 return true; 165 } 166 167 // Emit all builders. Returns false on success because of the generator 168 // registration requirements. 169 static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) { 170 for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) { 171 if (!emitOneBuilder(*def, os)) 172 return true; 173 } 174 return false; 175 } 176 177 namespace { 178 // Wrapper class around a Tablegen definition of an LLVM enum attribute case. 179 class LLVMEnumAttrCase : public tblgen::EnumAttrCase { 180 public: 181 using tblgen::EnumAttrCase::EnumAttrCase; 182 183 // Constructs a case from a non LLVM-specific enum attribute case. 184 explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other) 185 : tblgen::EnumAttrCase(&other.getDef()) {} 186 187 // Returns the C++ enumerant for the LLVM API. 188 StringRef getLLVMEnumerant() const { 189 return def->getValueAsString("llvmEnumerant"); 190 } 191 }; 192 193 // Wraper class around a Tablegen definition of an LLVM enum attribute. 194 class LLVMEnumAttr : public tblgen::EnumAttr { 195 public: 196 using tblgen::EnumAttr::EnumAttr; 197 198 // Returns the C++ enum name for the LLVM API. 199 StringRef getLLVMClassName() const { 200 return def->getValueAsString("llvmClassName"); 201 } 202 203 // Returns all associated cases viewed as LLVM-specific enum cases. 204 std::vector<LLVMEnumAttrCase> getAllCases() const { 205 std::vector<LLVMEnumAttrCase> cases; 206 207 for (auto &c : tblgen::EnumAttr::getAllCases()) 208 cases.emplace_back(c); 209 210 return cases; 211 } 212 }; 213 } // namespace 214 215 // Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing 216 // switch-based logic to convert from the MLIR LLVM dialect enum attribute case 217 // (Enum) to the corresponding LLVM API enumerant 218 static void emitOneEnumToConversion(const llvm::Record *record, 219 raw_ostream &os) { 220 LLVMEnumAttr enumAttr(record); 221 StringRef llvmClass = enumAttr.getLLVMClassName(); 222 StringRef cppClassName = enumAttr.getEnumClassName(); 223 StringRef cppNamespace = enumAttr.getCppNamespace(); 224 225 // Emit the function converting the enum attribute to its LLVM counterpart. 226 os << formatv( 227 "static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n", 228 llvmClass, cppClassName, cppNamespace); 229 os << " switch (value) {\n"; 230 231 for (const auto &enumerant : enumAttr.getAllCases()) { 232 StringRef llvmEnumerant = enumerant.getLLVMEnumerant(); 233 StringRef cppEnumerant = enumerant.getSymbol(); 234 os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName, 235 cppEnumerant); 236 os << formatv(" return {0}::{1};\n", llvmClass, llvmEnumerant); 237 } 238 239 os << " }\n"; 240 os << formatv(" llvm_unreachable(\"unknown {0} type\");\n", 241 enumAttr.getEnumClassName()); 242 os << "}\n\n"; 243 } 244 245 // Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and 246 // containing switch-based logic to convert from the LLVM API enumerant to MLIR 247 // LLVM dialect enum attribute (Enum). 248 static void emitOneEnumFromConversion(const llvm::Record *record, 249 raw_ostream &os) { 250 LLVMEnumAttr enumAttr(record); 251 StringRef llvmClass = enumAttr.getLLVMClassName(); 252 StringRef cppClassName = enumAttr.getEnumClassName(); 253 StringRef cppNamespace = enumAttr.getCppNamespace(); 254 255 // Emit the function converting the enum attribute from its LLVM counterpart. 256 os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} " 257 "value) {{\n", 258 cppNamespace, cppClassName, llvmClass); 259 os << " switch (value) {\n"; 260 261 for (const auto &enumerant : enumAttr.getAllCases()) { 262 StringRef llvmEnumerant = enumerant.getLLVMEnumerant(); 263 StringRef cppEnumerant = enumerant.getSymbol(); 264 os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant); 265 os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName, 266 cppEnumerant); 267 } 268 269 os << " }\n"; 270 os << formatv(" llvm_unreachable(\"unknown {0} type\");", 271 enumAttr.getLLVMClassName()); 272 os << "}\n\n"; 273 } 274 275 // Emits conversion functions between MLIR enum attribute case and corresponding 276 // LLVM API enumerants for all registered LLVM dialect enum attributes. 277 template <bool ConvertTo> 278 static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper, 279 raw_ostream &os) { 280 for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr")) 281 if (ConvertTo) 282 emitOneEnumToConversion(def, os); 283 else 284 emitOneEnumFromConversion(def, os); 285 286 return false; 287 } 288 289 static mlir::GenRegistration 290 genLLVMIRConversions("gen-llvmir-conversions", 291 "Generate LLVM IR conversions", emitBuilders); 292 293 static mlir::GenRegistration 294 genEnumToLLVMConversion("gen-enum-to-llvmir-conversions", 295 "Generate conversions of EnumAttrs to LLVM IR", 296 emitEnumConversionDefs</*ConvertTo=*/true>); 297 298 static mlir::GenRegistration 299 genEnumFromLLVMConversion("gen-enum-from-llvmir-conversions", 300 "Generate conversions of EnumAttrs from LLVM IR", 301 emitEnumConversionDefs</*ConvertTo=*/false>); 302