//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses the ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"

using namespace mlir;
using namespace mlir::tblgen;

/// File header and imports, emitted once at the top of every generated Python
/// module. It imports the native extension module `_cext` and the accessor
/// helpers `_segmented_accessor` / `_equally_sized_accessor` from the enclosing
/// package, and binds `_cext.ir` to the short alias `_ir` used by the other
/// templates.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from . import _cext
from . import _segmented_accessor, _equally_sized_accessor
_ir = _cext.ir
)Py";

/// Template for the dialect class, emitted once per file after the header.
/// The class is registered through `@_cext.register_dialect`, and every
/// generated op class refers to it by the name `_Dialect`:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_cext.register_dialect
class _Dialect(_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for an operation class deriving from `_ir.OpView` and registered
/// with `_Dialect`. The accessor properties produced by the templates below
/// are appended directly after it, forming the class body:
///   {0} is the Python class name;
///   {1} is the operation name.
constexpr const char *opClassTemplate = R"Py(
@_cext.register_operation(_Dialect)
class {0}(_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
constexpr const char *opSingleTemplate = R"Py(
  @property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// The optional element is the only variable-length group, so it is present
/// exactly when all {2} groups are, i.e. when there are at least {2} elements;
/// otherwise the accessor returns None. The conditional is kept on one line
/// because a bare line break would end the Python statement.
constexpr const char *opOneOptionalTemplate = R"Py(
  @property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
constexpr const char *opOneVariadicTemplate = R"Py(
  @property
  def {0}(self):
    variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// It binds `start` (index of the group's first element) and `pg` (elements
/// per group) for the second part. The element list must be reached through
/// `self`; a bare `operation` is unbound inside the generated method.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @property
  def {0}(self):
    start, pg = _equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
132 constexpr const char *opVariadicSegmentTemplate = R"Py( 133 @property 134 def {0}(self): 135 {1}_range = _segmented_accessor( 136 self.operation.{1}s, 137 self.operation.attributes["{1}_segment_sizes"], {2}) 138 return {1}_range{3} 139 )Py"; 140 141 /// Template for a suffix when accessing an optional element in the 142 /// attribute-sized case: 143 /// {0} is either 'operand' or 'result'; 144 constexpr const char *opVariadicSegmentOptionalTrailingTemplate = 145 R"Py([0] if len({0}_range) > 0 else None)Py"; 146 147 static llvm::cl::OptionCategory 148 clOpPythonBindingCat("Options for -gen-python-op-bindings"); 149 150 static llvm::cl::opt<std::string> 151 clDialectName("bind-dialect", 152 llvm::cl::desc("The dialect to run the generator for"), 153 llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); 154 155 /// Checks whether `str` is a Python keyword. 156 static bool isPythonKeyword(StringRef str) { 157 static llvm::StringSet<> keywords( 158 {"and", "as", "assert", "break", "class", "continue", 159 "def", "del", "elif", "else", "except", "finally", 160 "for", "from", "global", "if", "import", "in", 161 "is", "lambda", "nonlocal", "not", "or", "pass", 162 "raise", "return", "try", "while", "with", "yield"}); 163 return keywords.contains(str); 164 }; 165 166 /// Modifies the `name` in a way that it becomes suitable for Python bindings 167 /// (does not change the `name` if it already is suitable) and returns the 168 /// modified version. 169 static std::string sanitizeName(StringRef name) { 170 if (isPythonKeyword(name)) 171 return (name + "_").str(); 172 return name.str(); 173 } 174 175 /// Emits accessors to "elements" of an Op definition. Currently, the supported 176 /// elements are operands and results, indicated by `kind`, which must be either 177 /// `operand` or `result` and is used verbatim in the emitted code. 
178 static void emitElementAccessors( 179 const Operator &op, raw_ostream &os, const char *kind, 180 llvm::function_ref<unsigned(const Operator &)> getNumVariadic, 181 llvm::function_ref<int(const Operator &)> getNumElements, 182 llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> 183 getElement) { 184 assert(llvm::is_contained( 185 llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && 186 "unsupported kind"); 187 188 // Traits indicating how to process variadic elements. 189 std::string sameSizeTrait = 190 llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", 191 llvm::StringRef(kind).take_front().upper(), 192 llvm::StringRef(kind).drop_front()); 193 std::string attrSizedTrait = 194 llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", 195 llvm::StringRef(kind).take_front().upper(), 196 llvm::StringRef(kind).drop_front()); 197 198 unsigned numVariadic = getNumVariadic(op); 199 200 // If there is only one variadic element group, its size can be inferred from 201 // the total number of elements. If there are none, the generation is 202 // straightforward. 203 if (numVariadic <= 1) { 204 bool seenVariableLength = false; 205 for (int i = 0, e = getNumElements(op); i < e; ++i) { 206 const NamedTypeConstraint &element = getElement(op, i); 207 if (element.isVariableLength()) 208 seenVariableLength = true; 209 if (element.name.empty()) 210 continue; 211 if (element.isVariableLength()) { 212 os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate 213 : opOneVariadicTemplate, 214 sanitizeName(element.name), kind, 215 getNumElements(op), i); 216 } else if (seenVariableLength) { 217 os << llvm::formatv(opSingleAfterVariableTemplate, 218 sanitizeName(element.name), kind, 219 getNumElements(op), i); 220 } else { 221 os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, 222 i); 223 } 224 } 225 return; 226 } 227 228 // Handle the operations where variadic groups have the same size. 
229 if (op.getTrait(sameSizeTrait)) { 230 int numPrecedingSimple = 0; 231 int numPrecedingVariadic = 0; 232 for (int i = 0, e = getNumElements(op); i < e; ++i) { 233 const NamedTypeConstraint &element = getElement(op, i); 234 if (!element.name.empty()) { 235 os << llvm::formatv(opVariadicEqualPrefixTemplate, 236 sanitizeName(element.name), kind, numVariadic, 237 numPrecedingSimple, numPrecedingVariadic); 238 os << llvm::formatv(element.isVariableLength() 239 ? opVariadicEqualVariadicTemplate 240 : opVariadicEqualSimpleTemplate, 241 kind); 242 } 243 if (element.isVariableLength()) 244 ++numPrecedingVariadic; 245 else 246 ++numPrecedingSimple; 247 } 248 return; 249 } 250 251 // Handle the operations where the size of groups (variadic or not) is 252 // provided as an attribute. For non-variadic elements, make sure to return 253 // an element rather than a singleton container. 254 if (op.getTrait(attrSizedTrait)) { 255 for (int i = 0, e = getNumElements(op); i < e; ++i) { 256 const NamedTypeConstraint &element = getElement(op, i); 257 if (element.name.empty()) 258 continue; 259 std::string trailing; 260 if (!element.isVariableLength()) 261 trailing = "[0]"; 262 else if (element.isOptional()) 263 trailing = std::string( 264 llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); 265 os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), 266 kind, i, trailing); 267 } 268 return; 269 } 270 271 llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); 272 } 273 274 /// Emits accessor to Op operands. 
275 static void emitOperandAccessors(const Operator &op, raw_ostream &os) { 276 auto getNumVariadic = [](const Operator &oper) { 277 return oper.getNumVariableLengthOperands(); 278 }; 279 auto getNumElements = [](const Operator &oper) { 280 return oper.getNumOperands(); 281 }; 282 auto getElement = [](const Operator &oper, 283 int i) -> const NamedTypeConstraint & { 284 return oper.getOperand(i); 285 }; 286 emitElementAccessors(op, os, "operand", getNumVariadic, getNumElements, 287 getElement); 288 } 289 290 /// Emits access or Op results. 291 static void emitResultAccessors(const Operator &op, raw_ostream &os) { 292 auto getNumVariadic = [](const Operator &oper) { 293 return oper.getNumVariableLengthResults(); 294 }; 295 auto getNumElements = [](const Operator &oper) { 296 return oper.getNumResults(); 297 }; 298 auto getElement = [](const Operator &oper, 299 int i) -> const NamedTypeConstraint & { 300 return oper.getResult(i); 301 }; 302 emitElementAccessors(op, os, "result", getNumVariadic, getNumElements, 303 getElement); 304 } 305 306 /// Emits bindings for a specific Op to the given output stream. 307 static void emitOpBindings(const Operator &op, raw_ostream &os) { 308 os << llvm::formatv(opClassTemplate, op.getCppClassName(), 309 op.getOperationName()); 310 emitOperandAccessors(op, os); 311 emitResultAccessors(op, os); 312 } 313 314 /// Emits bindings for the dialect specified in the command line, including file 315 /// headers and utilities. Returns `false` on success to comply with Tablegen 316 /// registration requirements. 
317 static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { 318 if (clDialectName.empty()) 319 llvm::PrintFatalError("dialect name not provided"); 320 321 os << fileHeader; 322 os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); 323 for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { 324 Operator op(rec); 325 if (op.getDialectName() == clDialectName.getValue()) 326 emitOpBindings(op, os); 327 } 328 return false; 329 } 330 331 static GenRegistration 332 genPythonBindings("gen-python-op-bindings", 333 "Generate Python bindings for MLIR Ops", &emitAllOps); 334