//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//

#include "OpClass.h"
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

#define DEBUG_TYPE "mlir-tblgen-opdefgen"

using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;

// Prefix for local variables declared in generated code. The verifier, for
// example, names its attribute locals `tblgen_<attr>` so they do not hide the
// attribute accessor of the same name.
static const char *const tblgenNamePrefix = "tblgen_";
// Base name for operands that have no name in ODS. getArgumentName appends
// `_<index>` to it, producing e.g. `odsArg_0`.
static const char *const generatedArgName = "odsArg";
// Identifier names used in generated builder code. Presumably these name the
// Builder and OperationState parameters of generated build() methods; their
// uses are outside this part of the file (TODO confirm).
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";

/// Code for an Op to lookup an attribute. Uses cached identifiers: the name is
/// obtained through the generated `<getter>AttrName()` accessor rather than
/// being rebuilt from a string literal.
///
/// {0}: The attribute's getter name.
static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// The snippet is the body of a generated method that takes an
/// `unsigned index` parameter and returns a {start, size} pair. `{{` is the
/// formatv escape for a literal `{`. {4} is only used inside the generated
/// comments.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// {0}: The name of the attribute specifying the segment sizes.
// Adaptor variant: reads the segment-size attribute from `odsAttrs` by its
// string name and binds it to a local `sizeAttr`.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert(odsAttrs && "missing segment size attribute for op");
  auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
)";
// Op variant of the snippet above. It binds the same local `sizeAttr`, but
// looks the attribute up through the generated `{0}AttrName()` accessor, so
// {0} is the getter-name prefix of that accessor.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr =
      (*this)->getAttr({0}AttrName()).cast<::mlir::DenseIntElementsAttr>();
)";
// Computes the {start, size} pair of segment `index` from a local `sizeAttr`
// that one of the init snippets above must have defined. A splat attribute
// means every segment has the same size. This snippet is streamed verbatim
// rather than passed through formatv, so its braces are not escaped.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
  if (sizeAttr.isSplat())
    return {*sizeAttrValueIt * index, *sizeAttrValueIt};

  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttrValueIt[i];
  return {start, sizeAttrValueIt[index]};
)";

/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
///
/// The generated code slices the flat operand range of operand {1} into one
/// ValueRange per entry of the segment attribute and returns the groups.
///
/// {0}: The name of the segment attribute.
/// {1}: The index of the main operand.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizeAttrValues = {0}().getValues<uint32_t>();
  auto sizeAttrIt = sizeAttrValues.begin();

  ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
  for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt);
  }
  return tblgenTmpOperandGroups;
)";

/// The logic to build a range of either operand or result values.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
126 static const char *const valueRangeReturnCode = R"( 127 auto valueRange = {1}; 128 return {{std::next({0}, valueRange.first), 129 std::next({0}, valueRange.first + valueRange.second)}; 130 )"; 131 132 /// A header for indicating code sections. 133 /// 134 /// {0}: Some text, or a class name. 135 /// {1}: Some text. 136 static const char *const opCommentHeader = R"( 137 //===----------------------------------------------------------------------===// 138 // {0} {1} 139 //===----------------------------------------------------------------------===// 140 141 )"; 142 143 //===----------------------------------------------------------------------===// 144 // Utility structs and functions 145 //===----------------------------------------------------------------------===// 146 147 // Replaces all occurrences of `match` in `str` with `substitute`. 148 static std::string replaceAllSubstrs(std::string str, const std::string &match, 149 const std::string &substitute) { 150 std::string::size_type scanLoc = 0, matchLoc = std::string::npos; 151 while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) { 152 str = str.replace(matchLoc, match.size(), substitute); 153 scanLoc = matchLoc + substitute.size(); 154 } 155 return str; 156 } 157 158 // Returns whether the record has a value of the given name that can be returned 159 // via getValueAsString. 160 static inline bool hasStringAttribute(const Record &record, 161 StringRef fieldName) { 162 auto *valueInit = record.getValueInit(fieldName); 163 return isa<StringInit>(valueInit); 164 } 165 166 static std::string getArgumentName(const Operator &op, int index) { 167 const auto &operand = op.getOperand(index); 168 if (!operand.name.empty()) 169 return std::string(operand.name); 170 return std::string(formatv("{0}_{1}", generatedArgName, index)); 171 } 172 173 // Returns true if we can use unwrapped value for the given `attr` in builders. 
static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
  // A raw (unwrapped) value is only distinct from the attribute when the
  // attribute's return type differs from its storage type.
  return attr.getReturnType() != attr.getStorageType() &&
         // We need to wrap the raw value into an attribute in the builder impl
         // so we need to make sure that the attribute specifies how to do that.
         !attr.getConstBuilderTemplate().empty();
}

namespace {
/// Helper class to select between OpAdaptor and Op code templates.
///
/// The same verification/substitution logic is emitted both into the op class
/// (which reaches attributes through `(*this)`) and into its adaptor (which
/// reaches attributes through `odsAttrs` and reports errors via
/// `emitError(loc, ...)`). Each accessor returns the snippet that matches the
/// selected context.
class OpOrAdaptorHelper {
public:
  OpOrAdaptorHelper(const Operator &op, bool emitForOp)
      : op(op), emitForOp(emitForOp) {}

  /// Object that wraps a functor in a stream operator for interop with
  /// llvm::formatv.
  ///
  /// The constructor is deliberately implicit so the accessors below can
  /// return a lambda directly. The functor runs each time the Formatter is
  /// streamed. The lambdas returned below capture `this`, so a Formatter must
  /// not outlive the helper that produced it.
  class Formatter {
  public:
    template <typename Functor>
    Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}

    /// Renders the formatter into a std::string.
    std::string str() const {
      std::string result;
      llvm::raw_string_ostream os(result);
      os << *this;
      return os.str();
    }

  private:
    std::function<raw_ostream &(raw_ostream &)> func;

    friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
      return fmt.func(os);
    }
  };

  // Generate code for getting an attribute: `odsAttrs.get("<name>")` for an
  // adaptor, or the opGetAttr pattern (via the generated `<getter>AttrName()`
  // accessor) for an op.
  Formatter getAttr(StringRef attrName) const {
    return [this, attrName](raw_ostream &os) -> raw_ostream & {
      if (!emitForOp)
        return os << formatv("odsAttrs.get(\"{0}\")", attrName);
      return os << formatv(opGetAttr, op.getGetterName(attrName));
    };
  }

  // Get the prefix code for emitting an error. The prefix leaves a call open;
  // the caller appends the message and the closing parenthesis.
  Formatter emitErrorPrefix() const {
    return [this](raw_ostream &os) -> raw_ostream & {
      if (emitForOp)
        return os << "emitOpError(";
      return os << formatv("emitError(loc, \"'{0}' op \"",
                           op.getOperationName());
    };
  }

  // Get the call to get an operand or segment of operands: the whole
  // `getODSOperands(index)` range for a variadic operand, otherwise the first
  // element of that range.
  Formatter getOperand(unsigned index) const {
    return [this, index](raw_ostream &os) -> raw_ostream & {
      return os << formatv(op.getOperand(index).isVariadic()
                               ? "this->getODSOperands({0})"
                               : "(*this->getODSOperands({0}).begin())",
                           index);
    };
  }

  // Get the call to get a result or segment of results. Results are only
  // reachable from an op; for an adaptor a placeholder string is emitted.
  Formatter getResult(unsigned index) const {
    return [this, index](raw_ostream &os) -> raw_ostream & {
      if (!emitForOp)
        return os << "<no results should be generated>";
      return os << formatv(op.getResult(index).isVariadic()
                               ? "this->getODSResults({0})"
                               : "(*this->getODSResults({0}).begin())",
                           index);
    };
  }

  // Return whether an op instance is available.
  bool isEmittingForOp() const { return emitForOp; }

  // Return the ODS operation wrapper.
  const Operator &getOp() const { return op; }

private:
  // The operation ODS wrapper.
  const Operator &op;
  // True if code is being generated for an op. False for an adaptor.
  const bool emitForOp;
};
} // namespace

//===----------------------------------------------------------------------===//
// Op emitter
//===----------------------------------------------------------------------===//

namespace {
// Helper class to emit a record into the given output stream.
//
// All code generation happens in the constructor, which populates `opClass`.
// The static emitDecl/emitDef entry points then finalize that class and write
// its declaration or definition.
class OpEmitter {
public:
  // Writes the C++ class declaration for `op` to `os`.
  static void
  emitDecl(const Operator &op, raw_ostream &os,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter);
  // Writes the C++ class definition for `op` to `os`.
  static void
  emitDef(const Operator &op, raw_ostream &os,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter);

private:
  OpEmitter(const Operator &op,
            const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  void emitDecl(raw_ostream &os);
  void emitDef(raw_ostream &os);

  // Generate methods for accessing the attribute names of this operation.
  void genAttrNameGetters();

  // Generates the OpAsmOpInterface for this operation if possible.
  void genOpAsmInterface();

  // Generates the `getOperationName` method for this op.
  void genOpNameGetter();

  // Generates getters for the attributes.
  void genAttrGetters();

  // Generates setters for the attributes.
  void genAttrSetters();

  // Generates removers for optional attributes.
  void genOptionalAttrRemovers();

  // Generates getters for named operands.
  void genNamedOperandGetters();

  // Generates setters for named operands.
  void genNamedOperandSetters();

  // Generates getters for named results.
  void genNamedResultGetters();

  // Generates getters for named regions.
  void genNamedRegionGetters();

  // Generates getters for named successors.
  void genNamedSuccessorGetters();

  // Generates builder methods for the operation.
  void genBuilder();

  // Generates the build() method that takes each operand/attribute
  // as a stand-alone parameter.
  void genSeparateArgParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first operand's
  // type as all results' types.
  void genUseOperandAsResultTypeSeparateParamBuilder();

  // Generates the build() method that takes all operands/attributes
  // collectively as one parameter. The generated build() method uses first
  // operand's type as all results' types.
  void genUseOperandAsResultTypeCollectiveParamBuilder();

  // Generates the build() method that takes aggregate operands/attributes
  // parameters. This build() method uses inferred types as result types.
  // Requires: The type needs to be inferable via InferTypeOpInterface.
  void genInferredTypeCollectiveParamBuilder();

  // Generates the build() method that takes each operand/attribute as a
  // stand-alone parameter. The generated build() method uses first attribute's
  // type as all result's types.
  void genUseAttrAsResultTypeBuilder();

  // Generates the build() method that takes all result types collectively as
  // one parameter. Similarly for operands and attributes.
  void genCollectiveParamBuilder();

  // The kind of parameter to generate for result types in builders.
  enum class TypeParamKind {
    None,       // No result type in parameter list.
    Separate,   // A separate parameter for each result type.
    Collective, // An ArrayRef<Type> for all result types.
  };

  // The kind of parameter to generate for attributes in builders.
  enum class AttrParamKind {
    WrappedAttr,    // A wrapped MLIR Attribute instance.
    UnwrappedValue, // A raw value without MLIR Attribute wrapper.
  };

  // Builds the parameter list for build() method of this op. This method writes
  // to `paramList` the comma-separated parameter list and updates
  // `resultTypeNames` with the names for parameters for specifying result
  // types. `inferredAttributes` is populated with any attributes that are
  // elided from the build list. The given `typeParamKind` and `attrParamKind`
  // controls how result types and attributes are placed in the parameter list.
  void buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                      llvm::StringSet<> &inferredAttributes,
                      SmallVectorImpl<std::string> &resultTypeNames,
                      TypeParamKind typeParamKind,
                      AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);

  // Adds op arguments and regions into operation state for build() methods.
  void
  genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
                                         llvm::StringSet<> &inferredAttributes,
                                         bool isRawValueAttr = false);

  // Generates canonicalizer declaration for the operation.
  void genCanonicalizerDecls();

  // Generates the folder declaration for the operation.
  void genFolderDecls();

  // Generates the parser for the operation.
  void genParser();

  // Generates the printer for the operation.
  void genPrinter();

  // Generates verify method for the operation.
  void genVerifier();

  // Generates verify statements for operands and results in the operation.
  // The generated code will be attached to `body`.
  void genOperandResultVerifier(MethodBody &body,
                                Operator::const_value_range values,
                                StringRef valueKind);

  // Generates verify statements for regions in the operation.
  // The generated code will be attached to `body`.
  void genRegionVerifier(MethodBody &body);

  // Generates verify statements for successors in the operation.
  // The generated code will be attached to `body`.
  void genSuccessorVerifier(MethodBody &body);

  // Generates the traits used by the object.
  void genTraits();

  // Generate the OpInterface methods for all interfaces.
  void genOpInterfaceMethods();

  // Generate op interface methods for the given interface.
  void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);

  // Generate op interface method for the given interface method. If
  // 'declaration' is true, generates a declaration, else a definition.
  Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                               bool declaration = true);

  // Generate the side effect interface methods.
  void genSideEffectInterfaceMethods();

  // Generate the type inference interface methods.
  void genTypeInterfaceMethods();

private:
  // The TableGen record for this op.
  // TODO: OpEmitter should not have a Record directly,
  // it should rather go through the Operator for better abstraction.
  const Record &def;

  // The wrapper operator class for querying information from this op.
  Operator op;

  // The C++ code builder for this op.
  OpClass opClass;

  // The format context for verification code generation.
  FmtContext verifyCtx;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;
};

} // namespace

// Populate the format context `ctx` with substitutions of attributes, operands
// and results. Every attribute is added; operands and results only when they
// have an ODS name. Each substitution is the op- or adaptor-specific access
// code produced by `emitHelper`.
static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
                                  FmtContext &ctx) {
  // Populate substitutions for attributes.
  auto &op = emitHelper.getOp();
  for (const auto &namedAttr : op.getAttributes())
    ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str());

  // Populate substitutions for named operands.
  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
    auto &value = op.getOperand(i);
    if (!value.name.empty())
      ctx.addSubst(value.name, emitHelper.getOperand(i).str());
  }

  // Populate substitutions for results.
  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
    auto &value = op.getResult(i);
    if (!value.name.empty())
      ctx.addSubst(value.name, emitHelper.getResult(i).str());
  }
}

// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
// For each non-derived attribute that needs checking, a scoped block is
// appended to `body`. The block binds a `tblgen_<name>` local to the attribute
// and then emits a presence check (required attributes only) and/or a
// constraint check. Within an op, the constraint check calls a uniqued
// function from `staticVerifierEmitter` when one exists; otherwise the
// predicate is pasted inline.
static void genAttributeVerifier(
    const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
  // Check that a required attribute exists.
  //
  // {0}: Attribute variable name.
  // {1}: Emit error prefix.
  // {2}: Attribute name.
  const char *const verifyRequiredAttr = R"(
    if (!{0})
      return {1}"requires attribute '{2}'");
  )";
  // Verify the attribute if it is present. This assumes that default values
  // are valid. This code snippet pastes the condition inline.
  //
  // TODO: verify the default value is valid (perhaps in debug mode only).
  //
  // {0}: Attribute variable name.
  // {1}: Attribute condition code.
  // {2}: Emit error prefix.
  // {3}: Attribute name.
  // {4}: Attribute/constraint description.
  const char *const verifyAttrInline = R"(
    if ({0} && !({1}))
      return {2}"attribute '{3}' failed to satisfy constraint: {4}");
  )";
  // Verify the attribute using a uniqued constraint. Can only be used within
  // the context of an op.
  //
  // {0}: Unique constraint name.
  // {1}: Attribute variable name.
  // {2}: Attribute name.
  const char *const verifyAttrUnique = R"(
    if (::mlir::failed({0}(*this, {1}, "{2}")))
      return ::mlir::failure();
  )";

  for (const auto &namedAttr : emitHelper.getOp().getAttributes()) {
    const auto &attr = namedAttr.attr;
    StringRef attrName = namedAttr.name;
    if (attr.isDerivedAttr())
      continue;

    bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
    auto attrPred = attr.getPredicate();
    std::string condition = attrPred.isNull() ? "" : attrPred.getCondition();
    // If the attribute's condition needs an op but none is available, then the
    // condition cannot be emitted.
    bool canEmitCondition =
        !condition.empty() && (!StringRef(condition).contains("$_op") ||
                               emitHelper.isEmittingForOp());

    // Prefix with `tblgen_` to avoid hiding the attribute accessor.
    std::string varName = (tblgenNamePrefix + attrName).str();

    // If the attribute is not required and we cannot emit the condition, then
    // there is nothing to be done.
    if (allowMissingAttr && !canEmitCondition)
      continue;

    // Open a scope in the generated code so each attribute's local stays
    // confined to its own checks.
    body << formatv("  {\n    auto {0} = {1};", varName,
                    emitHelper.getAttr(attrName));

    if (!allowMissingAttr) {
      body << formatv(verifyRequiredAttr, varName, emitHelper.emitErrorPrefix(),
                      attrName);
    }
    if (canEmitCondition) {
      Optional<StringRef> constraintFn;
      if (emitHelper.isEmittingForOp() &&
          (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
        body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
      } else {
        body << formatv(verifyAttrInline, varName,
                        tgfmt(condition, &ctx.withSelf(varName)),
                        emitHelper.emitErrorPrefix(), attrName,
                        escapeString(attr.getSummary()));
      }
    }
    body << "  }\n";
  }
}

/// Op extra class definitions have a `$cppClass` substitution that is to be
/// replaced by the C++ class name.
static std::string formatExtraDefinitions(const Operator &op) {
  FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
  return tgfmt(op.getExtraClassDefinition(), &ctx).str();
}

// Builds the complete C++ class for `op` in `opClass`. Each gen* call below
// adds its methods to the class, so the call order is the order of methods in
// the generated output.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter) {
  // Inside verification code, `$_op` refers to the operation and `$_ctxt` to
  // its MLIRContext.
  verifyCtx.withOp("(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
void OpEmitter::emitDecl(
    const Operator &op, raw_ostream &os,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
  OpEmitter(op, staticVerifierEmitter).emitDecl(os);
}

void OpEmitter::emitDef(
    const Operator &op, raw_ostream &os,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
  OpEmitter(op, staticVerifierEmitter).emitDef(os);
}

// Finalizes the generated class and writes its declaration.
void OpEmitter::emitDecl(raw_ostream &os) {
  opClass.finalize();
  opClass.writeDeclTo(os);
}

// Finalizes the generated class and writes its definition.
void OpEmitter::emitDef(raw_ostream &os) {
  opClass.finalize();
  opClass.writeDefTo(os);
}

// Aborts TableGen with a fatal error when `m` is null, i.e. when a method could
// not be added to the generated class. The message attributes this to an
// overlap with an existing method. `line` is the line in this generator that
// requested the method; ERROR_IF_PRUNED supplies it via __LINE__.
static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
                          const Operator &op) {
  if (m)
    return;
  PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
                                   methodName + "` for " +
                                   op.getOperationName() + " (from line " +
                                   Twine(line) + ")");
}
#define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)

// Emits the attribute-name accessors:
//  - static `getAttributeNames()`, returning all registered attribute names in
//    registration order;
//  - private `getAttributeNameForIndex` helpers that index into that list;
//  - `<getter>AttrName()` accessors (instance and static) for each name.
// If the op has no attributes, only `getAttributeNames()` is emitted.
void OpEmitter::genAttrNameGetters() {
  // A map of attribute names (including implicit attributes) registered to the
  // current operation, to the relative order in which they were registered.
  llvm::MapVector<StringRef, unsigned> attributeNames;

  // Enumerate the attribute names of this op, assigning each a relative
  // ordering.
  auto addAttrName = [&](StringRef name) {
    unsigned index = attributeNames.size();
    attributeNames.insert({name, index});
  };
  for (const NamedAttribute &namedAttr : op.getAttributes())
    addAttrName(namedAttr.name);
  // Include key attributes from several traits as implicitly registered.
  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
    addAttrName("operand_segment_sizes");
  if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
    addAttrName("result_segment_sizes");

  // Emit the getAttributeNames method.
  {
    auto *method = opClass.addStaticInlineMethod(
        "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
    ERROR_IF_PRUNED(method, "getAttributeNames", op);
    auto &body = method->body();
    if (attributeNames.empty()) {
      body << "  return {};";
    } else {
      body << "  static ::llvm::StringRef attrNames[] = {";
      llvm::interleaveComma(llvm::make_first_range(attributeNames), body,
                            [&](StringRef attrName) {
                              body << "::llvm::StringRef(\"" << attrName
                                   << "\")";
                            });
      body << "};\n  return ::llvm::makeArrayRef(attrNames);";
    }
  }
  if (attributeNames.empty())
    return;

  // Emit the getAttributeNameForIndex methods.
  {
    auto *method = opClass.addInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
    method->body()
        << "  return getAttributeNameForIndex((*this)->getName(), index);";
  }
  {
    auto *method = opClass.addStaticInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("::mlir::OperationName", "name"),
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);

    // {0}: The number of registered attribute names (the bound for `index`).
    const char *const getAttrName = R"(
  assert(index < {0} && "invalid attribute index");
  return name.getRegisteredInfo()->getAttributeNames()[index];
)";
    method->body() << formatv(getAttrName, attributeNames.size());
  }

  // Generate the <attr>AttrName methods, that expose the attribute names to
  // users.
  const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
  for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
    for (StringRef name : op.getGetterNames(attrIt.first)) {
      std::string methodName = (name + "AttrName").str();

      // Generate the non-static variant.
      {
        auto *method =
            opClass.addInlineMethod("::mlir::StringAttr", methodName);
        ERROR_IF_PRUNED(method, methodName, op);
        method->body() << llvm::formatv(attrNameMethodBody, attrIt.second);
      }

      // Generate the static variant.
      {
        auto *method = opClass.addStaticInlineMethod(
            "::mlir::StringAttr", methodName,
            MethodParameter("::mlir::OperationName", "name"));
        ERROR_IF_PRUNED(method, methodName, op);
        method->body() << llvm::formatv(attrNameMethodBody,
                                        "name, " + Twine(attrIt.second));
      }
    }
  }
}

// Emit the getter for an attribute with the return type specified.
// It is templated to be shared between the Op and the adaptor class.
729 template <typename OpClassOrAdaptor> 730 static void emitAttrGetterWithReturnType(FmtContext &fctx, 731 OpClassOrAdaptor &opClass, 732 const Operator &op, StringRef name, 733 Attribute attr) { 734 auto *method = opClass.addMethod(attr.getReturnType(), name); 735 ERROR_IF_PRUNED(method, name, op); 736 auto &body = method->body(); 737 body << " auto attr = " << name << "Attr();\n"; 738 if (attr.hasDefaultValue()) { 739 // Returns the default value if not set. 740 // TODO: this is inefficient, we are recreating the attribute for every 741 // call. This should be set instead. 742 std::string defaultValue = std::string( 743 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); 744 body << " if (!attr)\n return " 745 << tgfmt(attr.getConvertFromStorageCall(), 746 &fctx.withSelf(defaultValue)) 747 << ";\n"; 748 } 749 body << " return " 750 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) 751 << ";\n"; 752 } 753 754 void OpEmitter::genAttrGetters() { 755 FmtContext fctx; 756 fctx.withBuilder("::mlir::Builder((*this)->getContext())"); 757 758 // Emit the derived attribute body. 759 auto emitDerivedAttr = [&](StringRef name, Attribute attr) { 760 if (auto *method = opClass.addMethod(attr.getReturnType(), name)) 761 method->body() << " " << attr.getDerivedCodeBody() << "\n"; 762 }; 763 764 // Generate named accessor with Attribute return type. This is a wrapper class 765 // that allows referring to the attributes via accessors instead of having to 766 // use the string interface for better compile time verification. 767 auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { 768 auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr"); 769 if (!method) 770 return; 771 method->body() << formatv( 772 " return {0}.{1}<{2}>();", formatv(opGetAttr, name), 773 attr.isOptional() || attr.hasDefaultValue() ? 
"dyn_cast_or_null" 774 : "cast", 775 attr.getStorageType()); 776 }; 777 778 for (const NamedAttribute &namedAttr : op.getAttributes()) { 779 for (StringRef name : op.getGetterNames(namedAttr.name)) { 780 if (namedAttr.attr.isDerivedAttr()) { 781 emitDerivedAttr(name, namedAttr.attr); 782 } else { 783 emitAttrWithStorageType(name, namedAttr.attr); 784 emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr); 785 } 786 } 787 } 788 789 auto derivedAttrs = make_filter_range(op.getAttributes(), 790 [](const NamedAttribute &namedAttr) { 791 return namedAttr.attr.isDerivedAttr(); 792 }); 793 if (derivedAttrs.empty()) 794 return; 795 796 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait"); 797 // Generate helper method to query whether a named attribute is a derived 798 // attribute. This enables, for example, avoiding adding an attribute that 799 // overlaps with a derived attribute. 800 { 801 auto *method = 802 opClass.addStaticMethod("bool", "isDerivedAttribute", 803 MethodParameter("::llvm::StringRef", "name")); 804 ERROR_IF_PRUNED(method, "isDerivedAttribute", op); 805 auto &body = method->body(); 806 for (auto namedAttr : derivedAttrs) 807 body << " if (name == \"" << namedAttr.name << "\") return true;\n"; 808 body << " return false;"; 809 } 810 // Generate method to materialize derived attributes as a DictionaryAttr. 
811 { 812 auto *method = opClass.addMethod("::mlir::DictionaryAttr", 813 "materializeDerivedAttributes"); 814 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op); 815 auto &body = method->body(); 816 817 auto nonMaterializable = 818 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) { 819 return namedAttr.attr.getConvertFromStorageCall().empty(); 820 }); 821 if (!nonMaterializable.empty()) { 822 std::string attrs; 823 llvm::raw_string_ostream os(attrs); 824 interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) { 825 os << op.getGetterName(attr.name); 826 }); 827 PrintWarning( 828 op.getLoc(), 829 formatv( 830 "op has non-materializable derived attributes '{0}', skipping", 831 os.str())); 832 body << formatv(" emitOpError(\"op has non-materializable derived " 833 "attributes '{0}'\");\n", 834 attrs); 835 body << " return nullptr;"; 836 return; 837 } 838 839 body << " ::mlir::MLIRContext* ctx = getContext();\n"; 840 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; 841 body << " return ::mlir::DictionaryAttr::get("; 842 body << " ctx, {\n"; 843 interleave( 844 derivedAttrs, body, 845 [&](const NamedAttribute &namedAttr) { 846 auto tmpl = namedAttr.attr.getConvertFromStorageCall(); 847 std::string name = op.getGetterName(namedAttr.name); 848 body << " {" << name << "AttrName(),\n" 849 << tgfmt(tmpl, &fctx.withSelf(name + "()") 850 .withBuilder("odsBuilder") 851 .addSubst("_ctx", "ctx")) 852 << "}"; 853 }, 854 ",\n"); 855 body << "});"; 856 } 857 } 858 859 void OpEmitter::genAttrSetters() { 860 // Generate raw named setter type. This is a wrapper class that allows setting 861 // to the attributes via setters instead of having to use the string interface 862 // for better compile time verification. 
863 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName, 864 Attribute attr) { 865 auto *method = 866 opClass.addMethod("void", setterName + "Attr", 867 MethodParameter(attr.getStorageType(), "attr")); 868 if (method) 869 method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);", 870 getterName); 871 }; 872 873 for (const NamedAttribute &namedAttr : op.getAttributes()) { 874 if (namedAttr.attr.isDerivedAttr()) 875 continue; 876 for (auto names : llvm::zip(op.getSetterNames(namedAttr.name), 877 op.getGetterNames(namedAttr.name))) 878 emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), 879 namedAttr.attr); 880 } 881 } 882 883 void OpEmitter::genOptionalAttrRemovers() { 884 // Generate methods for removing optional attributes, instead of having to 885 // use the string interface. Enables better compile time verification. 886 auto emitRemoveAttr = [&](StringRef name) { 887 auto upperInitial = name.take_front().upper(); 888 auto suffix = name.drop_front(); 889 auto *method = opClass.addMethod("::mlir::Attribute", 890 "remove" + upperInitial + suffix + "Attr"); 891 if (!method) 892 return; 893 method->body() << formatv(" return (*this)->removeAttr({0}AttrName());", 894 op.getGetterName(name)); 895 }; 896 897 for (const NamedAttribute &namedAttr : op.getAttributes()) 898 if (namedAttr.attr.isOptional()) 899 emitRemoveAttr(namedAttr.name); 900 } 901 902 // Generates the code to compute the start and end index of an operand or result 903 // range. 
904 template <typename RangeT> 905 static void 906 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName, 907 int numVariadic, int numNonVariadic, 908 StringRef rangeSizeCall, bool hasAttrSegmentSize, 909 StringRef sizeAttrInit, RangeT &&odsValues) { 910 auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName, 911 MethodParameter("unsigned", "index")); 912 if (!method) 913 return; 914 auto &body = method->body(); 915 if (numVariadic == 0) { 916 body << " return {index, 1};\n"; 917 } else if (hasAttrSegmentSize) { 918 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode; 919 } else { 920 // Because the op can have arbitrarily interleaved variadic and non-variadic 921 // operands, we need to embed a list in the "sink" getter method for 922 // calculation at run-time. 923 SmallVector<StringRef, 4> isVariadic; 924 isVariadic.reserve(llvm::size(odsValues)); 925 for (auto &it : odsValues) 926 isVariadic.push_back(it.isVariableLength() ? "true" : "false"); 927 std::string isVariadicList = llvm::join(isVariadic, ", "); 928 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, 929 numNonVariadic, numVariadic, rangeSizeCall, "operand"); 930 } 931 } 932 933 // Generates the named operand getter methods for the given Operator `op` and 934 // puts them in `opClass`. Uses `rangeType` as the return type of getters that 935 // return a range of operands (individual operands are `Value ` and each 936 // element in the range must also be `Value `); use `rangeBeginCall` to get 937 // an iterator to the beginning of the operand range; use `rangeSizeCall` to 938 // obtain the number of operands. `getOperandCallPattern` contains the code 939 // necessary to obtain a single operand whose position will be substituted 940 // instead of 941 // "{0}" marker in the pattern. Note that the pattern should work for any kind 942 // of ops, in particular for one-operand ops that may not have the 943 // `getOperand(unsigned)` method. 
944 static void generateNamedOperandGetters(const Operator &op, Class &opClass, 945 bool isAdaptor, StringRef sizeAttrInit, 946 StringRef rangeType, 947 StringRef rangeBeginCall, 948 StringRef rangeSizeCall, 949 StringRef getOperandCallPattern) { 950 const int numOperands = op.getNumOperands(); 951 const int numVariadicOperands = op.getNumVariableLengthOperands(); 952 const int numNormalOperands = numOperands - numVariadicOperands; 953 954 const auto *sameVariadicSize = 955 op.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); 956 const auto *attrSizedOperands = 957 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); 958 959 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) { 960 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " 961 "specification over their sizes"); 962 } 963 964 if (numVariadicOperands < 2 && attrSizedOperands) { 965 PrintFatalError(op.getLoc(), "op must have at least two variadic operands " 966 "to use 'AttrSizedOperandSegments' trait"); 967 } 968 969 if (attrSizedOperands && sameVariadicSize) { 970 PrintFatalError(op.getLoc(), 971 "op cannot have both 'AttrSizedOperandSegments' and " 972 "'SameVariadicOperandSize' traits"); 973 } 974 975 // First emit a few "sink" getter methods upon which we layer all nicer named 976 // getter methods. 977 generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength", 978 numVariadicOperands, numNormalOperands, 979 rangeSizeCall, attrSizedOperands, sizeAttrInit, 980 const_cast<Operator &>(op).getOperands()); 981 982 auto *m = opClass.addMethod(rangeType, "getODSOperands", 983 MethodParameter("unsigned", "index")); 984 ERROR_IF_PRUNED(m, "getODSOperands", op); 985 auto &body = m->body(); 986 body << formatv(valueRangeReturnCode, rangeBeginCall, 987 "getODSOperandIndexAndLength(index)"); 988 989 // Then we emit nicer named getter methods by redirecting to the "sink" getter 990 // method. 
991 for (int i = 0; i != numOperands; ++i) { 992 const auto &operand = op.getOperand(i); 993 if (operand.name.empty()) 994 continue; 995 for (StringRef name : op.getGetterNames(operand.name)) { 996 if (operand.isOptional()) { 997 m = opClass.addMethod("::mlir::Value", name); 998 ERROR_IF_PRUNED(m, name, op); 999 m->body() << " auto operands = getODSOperands(" << i << ");\n" 1000 << " return operands.empty() ? ::mlir::Value() : " 1001 "*operands.begin();"; 1002 } else if (operand.isVariadicOfVariadic()) { 1003 std::string segmentAttr = op.getGetterName( 1004 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); 1005 if (isAdaptor) { 1006 m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", 1007 name); 1008 ERROR_IF_PRUNED(m, name, op); 1009 m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, 1010 segmentAttr, i); 1011 continue; 1012 } 1013 1014 m = opClass.addMethod("::mlir::OperandRangeRange", name); 1015 ERROR_IF_PRUNED(m, name, op); 1016 m->body() << " return getODSOperands(" << i << ").split(" 1017 << segmentAttr << "Attr());"; 1018 } else if (operand.isVariadic()) { 1019 m = opClass.addMethod(rangeType, name); 1020 ERROR_IF_PRUNED(m, name, op); 1021 m->body() << " return getODSOperands(" << i << ");"; 1022 } else { 1023 m = opClass.addMethod("::mlir::Value", name); 1024 ERROR_IF_PRUNED(m, name, op); 1025 m->body() << " return *getODSOperands(" << i << ").begin();"; 1026 } 1027 } 1028 } 1029 } 1030 1031 void OpEmitter::genNamedOperandGetters() { 1032 // Build the code snippet used for initializing the operand_segment_size)s 1033 // array. 
1034 std::string attrSizeInitCode; 1035 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 1036 std::string attr = op.getGetterName("operand_segment_sizes"); 1037 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); 1038 } 1039 1040 generateNamedOperandGetters( 1041 op, opClass, 1042 /*isAdaptor=*/false, 1043 /*sizeAttrInit=*/attrSizeInitCode, 1044 /*rangeType=*/"::mlir::Operation::operand_range", 1045 /*rangeBeginCall=*/"getOperation()->operand_begin()", 1046 /*rangeSizeCall=*/"getOperation()->getNumOperands()", 1047 /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); 1048 } 1049 1050 void OpEmitter::genNamedOperandSetters() { 1051 auto *attrSizedOperands = 1052 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); 1053 for (int i = 0, e = op.getNumOperands(); i != e; ++i) { 1054 const auto &operand = op.getOperand(i); 1055 if (operand.name.empty()) 1056 continue; 1057 for (StringRef name : op.getGetterNames(operand.name)) { 1058 auto *m = opClass.addMethod(operand.isVariadicOfVariadic() 1059 ? "::mlir::MutableOperandRangeRange" 1060 : "::mlir::MutableOperandRange", 1061 (name + "Mutable").str()); 1062 ERROR_IF_PRUNED(m, name, op); 1063 auto &body = m->body(); 1064 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" 1065 << " auto mutableRange = " 1066 "::mlir::MutableOperandRange(getOperation(), " 1067 "range.first, range.second"; 1068 if (attrSizedOperands) 1069 body << ", ::mlir::MutableOperandRange::OperandSegment(" << i 1070 << "u, *getOperation()->getAttrDictionary().getNamed(" 1071 << op.getGetterName("operand_segment_sizes") << "AttrName()))"; 1072 body << ");\n"; 1073 1074 // If this operand is a nested variadic, we split the range into a 1075 // MutableOperandRangeRange that provides a range over all of the 1076 // sub-ranges. 
1077 if (operand.isVariadicOfVariadic()) { 1078 // 1079 body << " return " 1080 "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" 1081 << op.getGetterName( 1082 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) 1083 << "AttrName()));\n"; 1084 } else { 1085 // Otherwise, we use the full range directly. 1086 body << " return mutableRange;\n"; 1087 } 1088 } 1089 } 1090 } 1091 1092 void OpEmitter::genNamedResultGetters() { 1093 const int numResults = op.getNumResults(); 1094 const int numVariadicResults = op.getNumVariableLengthResults(); 1095 const int numNormalResults = numResults - numVariadicResults; 1096 1097 // If we have more than one variadic results, we need more complicated logic 1098 // to calculate the value range for each result. 1099 1100 const auto *sameVariadicSize = 1101 op.getTrait("::mlir::OpTrait::SameVariadicResultSize"); 1102 const auto *attrSizedResults = 1103 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"); 1104 1105 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) { 1106 PrintFatalError(op.getLoc(), "op has multiple variadic results but no " 1107 "specification over their sizes"); 1108 } 1109 1110 if (numVariadicResults < 2 && attrSizedResults) { 1111 PrintFatalError(op.getLoc(), "op must have at least two variadic results " 1112 "to use 'AttrSizedResultSegments' trait"); 1113 } 1114 1115 if (attrSizedResults && sameVariadicSize) { 1116 PrintFatalError(op.getLoc(), 1117 "op cannot have both 'AttrSizedResultSegments' and " 1118 "'SameVariadicResultSize' traits"); 1119 } 1120 1121 // Build the initializer string for the result segment size attribute. 
1122 std::string attrSizeInitCode; 1123 if (attrSizedResults) { 1124 std::string attr = op.getGetterName("result_segment_sizes"); 1125 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); 1126 } 1127 1128 generateValueRangeStartAndEnd( 1129 opClass, "getODSResultIndexAndLength", numVariadicResults, 1130 numNormalResults, "getOperation()->getNumResults()", attrSizedResults, 1131 attrSizeInitCode, op.getResults()); 1132 1133 auto *m = 1134 opClass.addMethod("::mlir::Operation::result_range", "getODSResults", 1135 MethodParameter("unsigned", "index")); 1136 ERROR_IF_PRUNED(m, "getODSResults", op); 1137 m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", 1138 "getODSResultIndexAndLength(index)"); 1139 1140 for (int i = 0; i != numResults; ++i) { 1141 const auto &result = op.getResult(i); 1142 if (result.name.empty()) 1143 continue; 1144 for (StringRef name : op.getGetterNames(result.name)) { 1145 if (result.isOptional()) { 1146 m = opClass.addMethod("::mlir::Value", name); 1147 ERROR_IF_PRUNED(m, name, op); 1148 m->body() 1149 << " auto results = getODSResults(" << i << ");\n" 1150 << " return results.empty() ? ::mlir::Value() : *results.begin();"; 1151 } else if (result.isVariadic()) { 1152 m = opClass.addMethod("::mlir::Operation::result_range", name); 1153 ERROR_IF_PRUNED(m, name, op); 1154 m->body() << " return getODSResults(" << i << ");"; 1155 } else { 1156 m = opClass.addMethod("::mlir::Value", name); 1157 ERROR_IF_PRUNED(m, name, op); 1158 m->body() << " return *getODSResults(" << i << ").begin();"; 1159 } 1160 } 1161 } 1162 } 1163 1164 void OpEmitter::genNamedRegionGetters() { 1165 unsigned numRegions = op.getNumRegions(); 1166 for (unsigned i = 0; i < numRegions; ++i) { 1167 const auto ®ion = op.getRegion(i); 1168 if (region.name.empty()) 1169 continue; 1170 1171 for (StringRef name : op.getGetterNames(region.name)) { 1172 // Generate the accessors for a variadic region. 
1173 if (region.isVariadic()) { 1174 auto *m = 1175 opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name); 1176 ERROR_IF_PRUNED(m, name, op); 1177 m->body() << formatv(" return (*this)->getRegions().drop_front({0});", 1178 i); 1179 continue; 1180 } 1181 1182 auto *m = opClass.addMethod("::mlir::Region &", name); 1183 ERROR_IF_PRUNED(m, name, op); 1184 m->body() << formatv(" return (*this)->getRegion({0});", i); 1185 } 1186 } 1187 } 1188 1189 void OpEmitter::genNamedSuccessorGetters() { 1190 unsigned numSuccessors = op.getNumSuccessors(); 1191 for (unsigned i = 0; i < numSuccessors; ++i) { 1192 const NamedSuccessor &successor = op.getSuccessor(i); 1193 if (successor.name.empty()) 1194 continue; 1195 1196 for (StringRef name : op.getGetterNames(successor.name)) { 1197 // Generate the accessors for a variadic successor list. 1198 if (successor.isVariadic()) { 1199 auto *m = opClass.addMethod("::mlir::SuccessorRange", name); 1200 ERROR_IF_PRUNED(m, name, op); 1201 m->body() << formatv( 1202 " return {std::next((*this)->successor_begin(), {0}), " 1203 "(*this)->successor_end()};", 1204 i); 1205 continue; 1206 } 1207 1208 auto *m = opClass.addMethod("::mlir::Block *", name); 1209 ERROR_IF_PRUNED(m, name, op); 1210 m->body() << formatv(" return (*this)->getSuccessor({0});", i); 1211 } 1212 } 1213 } 1214 1215 static bool canGenerateUnwrappedBuilder(Operator &op) { 1216 // If this op does not have native attributes at all, return directly to avoid 1217 // redefining builders. 1218 if (op.getNumNativeAttributes() == 0) 1219 return false; 1220 1221 bool canGenerate = false; 1222 // We are generating builders that take raw values for attributes. We need to 1223 // make sure the native attributes have a meaningful "unwrapped" value type 1224 // different from the wrapped mlir::Attribute type to avoid redefining 1225 // builders. This checks for the op has at least one such native attribute. 
1226 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { 1227 NamedAttribute &namedAttr = op.getAttribute(i); 1228 if (canUseUnwrappedRawValue(namedAttr.attr)) { 1229 canGenerate = true; 1230 break; 1231 } 1232 } 1233 return canGenerate; 1234 } 1235 1236 static bool canInferType(Operator &op) { 1237 return op.getTrait("::mlir::InferTypeOpInterface::Trait"); 1238 } 1239 1240 void OpEmitter::genSeparateArgParamBuilder() { 1241 SmallVector<AttrParamKind, 2> attrBuilderType; 1242 attrBuilderType.push_back(AttrParamKind::WrappedAttr); 1243 if (canGenerateUnwrappedBuilder(op)) 1244 attrBuilderType.push_back(AttrParamKind::UnwrappedValue); 1245 1246 // Emit with separate builders with or without unwrapped attributes and/or 1247 // inferring result type. 1248 auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, 1249 bool inferType) { 1250 SmallVector<MethodParameter> paramList; 1251 SmallVector<std::string, 4> resultNames; 1252 llvm::StringSet<> inferredAttributes; 1253 buildParamList(paramList, inferredAttributes, resultNames, paramKind, 1254 attrType); 1255 1256 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1257 // If the builder is redundant, skip generating the method. 1258 if (!m) 1259 return; 1260 auto &body = m->body(); 1261 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, 1262 /*isRawValueAttr=*/attrType == 1263 AttrParamKind::UnwrappedValue); 1264 1265 // Push all result types to the operation state 1266 1267 if (inferType) { 1268 // Generate builder that infers type too. 1269 // TODO: Subsume this with general checking if type can be 1270 // inferred automatically. 1271 // TODO: Expand to handle regions. 
1272 body << formatv(R"( 1273 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; 1274 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), 1275 {1}.location, {1}.operands, 1276 {1}.attributes.getDictionary({1}.getContext()), 1277 /*regions=*/{{}, inferredReturnTypes))) 1278 {1}.addTypes(inferredReturnTypes); 1279 else 1280 ::llvm::report_fatal_error("Failed to infer result type(s).");)", 1281 opClass.getClassName(), builderOpState); 1282 return; 1283 } 1284 1285 switch (paramKind) { 1286 case TypeParamKind::None: 1287 return; 1288 case TypeParamKind::Separate: 1289 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 1290 if (op.getResult(i).isOptional()) 1291 body << " if (" << resultNames[i] << ")\n "; 1292 body << " " << builderOpState << ".addTypes(" << resultNames[i] 1293 << ");\n"; 1294 } 1295 return; 1296 case TypeParamKind::Collective: { 1297 int numResults = op.getNumResults(); 1298 int numVariadicResults = op.getNumVariableLengthResults(); 1299 int numNonVariadicResults = numResults - numVariadicResults; 1300 bool hasVariadicResult = numVariadicResults != 0; 1301 1302 // Avoid emitting "resultTypes.size() >= 0u" which is always true. 1303 if (!(hasVariadicResult && numNonVariadicResults == 0)) 1304 body << " " 1305 << "assert(resultTypes.size() " 1306 << (hasVariadicResult ? ">=" : "==") << " " 1307 << numNonVariadicResults 1308 << "u && \"mismatched number of results\");\n"; 1309 body << " " << builderOpState << ".addTypes(resultTypes);\n"; 1310 } 1311 return; 1312 } 1313 llvm_unreachable("unhandled TypeParamKind"); 1314 }; 1315 1316 // Some of the build methods generated here may be ambiguous, but TableGen's 1317 // ambiguous function detection will elide those ones. 
1318 for (auto attrType : attrBuilderType) { 1319 emit(attrType, TypeParamKind::Separate, /*inferType=*/false); 1320 if (canInferType(op) && op.getNumRegions() == 0) 1321 emit(attrType, TypeParamKind::None, /*inferType=*/true); 1322 emit(attrType, TypeParamKind::Collective, /*inferType=*/false); 1323 } 1324 } 1325 1326 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { 1327 int numResults = op.getNumResults(); 1328 1329 // Signature 1330 SmallVector<MethodParameter> paramList; 1331 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 1332 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1333 paramList.emplace_back("::mlir::ValueRange", "operands"); 1334 // Provide default value for `attributes` when its the last parameter 1335 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; 1336 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 1337 "attributes", attributesDefaultValue); 1338 if (op.getNumVariadicRegions()) 1339 paramList.emplace_back("unsigned", "numRegions"); 1340 1341 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1342 // If the builder is redundant, skip generating the method 1343 if (!m) 1344 return; 1345 auto &body = m->body(); 1346 1347 // Operands 1348 body << " " << builderOpState << ".addOperands(operands);\n"; 1349 1350 // Attributes 1351 body << " " << builderOpState << ".addAttributes(attributes);\n"; 1352 1353 // Create the correct number of regions 1354 if (int numRegions = op.getNumRegions()) { 1355 body << llvm::formatv( 1356 " for (unsigned i = 0; i != {0}; ++i)\n", 1357 (op.getNumVariadicRegions() ? 
"numRegions" : Twine(numRegions))); 1358 body << " (void)" << builderOpState << ".addRegion();\n"; 1359 } 1360 1361 // Result types 1362 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()"); 1363 body << " " << builderOpState << ".addTypes({" 1364 << llvm::join(resultTypes, ", ") << "});\n\n"; 1365 } 1366 1367 void OpEmitter::genInferredTypeCollectiveParamBuilder() { 1368 // TODO: Expand to support regions. 1369 SmallVector<MethodParameter> paramList; 1370 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 1371 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1372 paramList.emplace_back("::mlir::ValueRange", "operands"); 1373 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 1374 "attributes", "{}"); 1375 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1376 // If the builder is redundant, skip generating the method 1377 if (!m) 1378 return; 1379 auto &body = m->body(); 1380 1381 int numResults = op.getNumResults(); 1382 int numVariadicResults = op.getNumVariableLengthResults(); 1383 int numNonVariadicResults = numResults - numVariadicResults; 1384 1385 int numOperands = op.getNumOperands(); 1386 int numVariadicOperands = op.getNumVariableLengthOperands(); 1387 int numNonVariadicOperands = numOperands - numVariadicOperands; 1388 1389 // Operands 1390 if (numVariadicOperands == 0 || numNonVariadicOperands != 0) 1391 body << " assert(operands.size()" 1392 << (numVariadicOperands != 0 ? " >= " : " == ") 1393 << numNonVariadicOperands 1394 << "u && \"mismatched number of parameters\");\n"; 1395 body << " " << builderOpState << ".addOperands(operands);\n"; 1396 body << " " << builderOpState << ".addAttributes(attributes);\n"; 1397 1398 // Create the correct number of regions 1399 if (int numRegions = op.getNumRegions()) { 1400 body << llvm::formatv( 1401 " for (unsigned i = 0; i != {0}; ++i)\n", 1402 (op.getNumVariadicRegions() ? 
"numRegions" : Twine(numRegions))); 1403 body << " (void)" << builderOpState << ".addRegion();\n"; 1404 } 1405 1406 // Result types 1407 body << formatv(R"( 1408 ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes; 1409 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), 1410 {1}.location, operands, 1411 {1}.attributes.getDictionary({1}.getContext()), 1412 {1}.regions, inferredReturnTypes))) {{)", 1413 opClass.getClassName(), builderOpState); 1414 if (numVariadicResults == 0 || numNonVariadicResults != 0) 1415 body << "\n assert(inferredReturnTypes.size()" 1416 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults 1417 << "u && \"mismatched number of return types\");"; 1418 body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);"; 1419 1420 body << formatv(R"( 1421 } else {{ 1422 ::llvm::report_fatal_error("Failed to infer result type(s)."); 1423 })", 1424 opClass.getClassName(), builderOpState); 1425 } 1426 1427 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { 1428 SmallVector<MethodParameter> paramList; 1429 SmallVector<std::string, 4> resultNames; 1430 llvm::StringSet<> inferredAttributes; 1431 buildParamList(paramList, inferredAttributes, resultNames, 1432 TypeParamKind::None); 1433 1434 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1435 // If the builder is redundant, skip generating the method 1436 if (!m) 1437 return; 1438 auto &body = m->body(); 1439 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes); 1440 1441 auto numResults = op.getNumResults(); 1442 if (numResults == 0) 1443 return; 1444 1445 // Push all result types to the operation state 1446 const char *index = op.getOperand(0).isVariadic() ? 
".front()" : ""; 1447 std::string resultType = 1448 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); 1449 body << " " << builderOpState << ".addTypes({" << resultType; 1450 for (int i = 1; i != numResults; ++i) 1451 body << ", " << resultType; 1452 body << "});\n\n"; 1453 } 1454 1455 void OpEmitter::genUseAttrAsResultTypeBuilder() { 1456 SmallVector<MethodParameter> paramList; 1457 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 1458 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1459 paramList.emplace_back("::mlir::ValueRange", "operands"); 1460 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 1461 "attributes", "{}"); 1462 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1463 // If the builder is redundant, skip generating the method 1464 if (!m) 1465 return; 1466 1467 auto &body = m->body(); 1468 1469 // Push all result types to the operation state 1470 std::string resultType; 1471 const auto &namedAttr = op.getAttribute(0); 1472 1473 body << " auto attrName = " << op.getGetterName(namedAttr.name) 1474 << "AttrName(" << builderOpState 1475 << ".name);\n" 1476 " for (auto attr : attributes) {\n" 1477 " if (attr.getName() != attrName) continue;\n"; 1478 if (namedAttr.attr.isTypeAttr()) { 1479 resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()"; 1480 } else { 1481 resultType = "attr.getValue().getType()"; 1482 } 1483 1484 // Operands 1485 body << " " << builderOpState << ".addOperands(operands);\n"; 1486 1487 // Attributes 1488 body << " " << builderOpState << ".addAttributes(attributes);\n"; 1489 1490 // Result types 1491 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType); 1492 body << " " << builderOpState << ".addTypes({" 1493 << llvm::join(resultTypes, ", ") << "});\n"; 1494 body << " }\n"; 1495 } 1496 1497 /// Returns a signature of the builder. 
Updates the context `fctx` to enable 1498 /// replacement of $_builder and $_state in the body. 1499 static SmallVector<MethodParameter> 1500 getBuilderSignature(const Builder &builder) { 1501 ArrayRef<Builder::Parameter> params(builder.getParameters()); 1502 1503 // Inject builder and state arguments. 1504 SmallVector<MethodParameter> arguments; 1505 arguments.reserve(params.size() + 2); 1506 arguments.emplace_back("::mlir::OpBuilder &", odsBuilder); 1507 arguments.emplace_back("::mlir::OperationState &", builderOpState); 1508 1509 for (unsigned i = 0, e = params.size(); i < e; ++i) { 1510 // If no name is provided, generate one. 1511 Optional<StringRef> paramName = params[i].getName(); 1512 std::string name = 1513 paramName ? paramName->str() : "odsArg" + std::to_string(i); 1514 1515 StringRef defaultValue; 1516 if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue()) 1517 defaultValue = *defaultParamValue; 1518 1519 arguments.emplace_back(params[i].getCppType(), std::move(name), 1520 defaultValue); 1521 } 1522 1523 return arguments; 1524 } 1525 1526 void OpEmitter::genBuilder() { 1527 // Handle custom builders if provided. 1528 for (const Builder &builder : op.getBuilders()) { 1529 SmallVector<MethodParameter> arguments = getBuilderSignature(builder); 1530 1531 Optional<StringRef> body = builder.getBody(); 1532 auto properties = body ? Method::Static : Method::StaticDeclaration; 1533 auto *method = 1534 opClass.addMethod("void", "build", properties, std::move(arguments)); 1535 if (body) 1536 ERROR_IF_PRUNED(method, "build", op); 1537 1538 FmtContext fctx; 1539 fctx.withBuilder(odsBuilder); 1540 fctx.addSubst("_state", builderOpState); 1541 if (body) 1542 method->body() << tgfmt(*body, &fctx); 1543 } 1544 1545 // Generate default builders that requires all result type, operands, and 1546 // attributes as parameters. 1547 if (op.skipDefaultBuilders()) 1548 return; 1549 1550 // We generate three classes of builders here: 1551 // 1. 
one having a stand-alone parameter for each operand / attribute, and 1552 genSeparateArgParamBuilder(); 1553 // 2. one having an aggregated parameter for all result types / operands / 1554 // attributes, and 1555 genCollectiveParamBuilder(); 1556 // 3. one having a stand-alone parameter for each operand and attribute, 1557 // use the first operand or attribute's type as all result types 1558 // to facilitate different call patterns. 1559 if (op.getNumVariableLengthResults() == 0) { 1560 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { 1561 genUseOperandAsResultTypeSeparateParamBuilder(); 1562 genUseOperandAsResultTypeCollectiveParamBuilder(); 1563 } 1564 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) 1565 genUseAttrAsResultTypeBuilder(); 1566 } 1567 } 1568 1569 void OpEmitter::genCollectiveParamBuilder() { 1570 int numResults = op.getNumResults(); 1571 int numVariadicResults = op.getNumVariableLengthResults(); 1572 int numNonVariadicResults = numResults - numVariadicResults; 1573 1574 int numOperands = op.getNumOperands(); 1575 int numVariadicOperands = op.getNumVariableLengthOperands(); 1576 int numNonVariadicOperands = numOperands - numVariadicOperands; 1577 1578 SmallVector<MethodParameter> paramList; 1579 paramList.emplace_back("::mlir::OpBuilder &", ""); 1580 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1581 paramList.emplace_back("::mlir::TypeRange", "resultTypes"); 1582 paramList.emplace_back("::mlir::ValueRange", "operands"); 1583 // Provide default value for `attributes` when its the last parameter 1584 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? 
"" : "{}"; 1585 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 1586 "attributes", attributesDefaultValue); 1587 if (op.getNumVariadicRegions()) 1588 paramList.emplace_back("unsigned", "numRegions"); 1589 1590 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1591 // If the builder is redundant, skip generating the method 1592 if (!m) 1593 return; 1594 auto &body = m->body(); 1595 1596 // Operands 1597 if (numVariadicOperands == 0 || numNonVariadicOperands != 0) 1598 body << " assert(operands.size()" 1599 << (numVariadicOperands != 0 ? " >= " : " == ") 1600 << numNonVariadicOperands 1601 << "u && \"mismatched number of parameters\");\n"; 1602 body << " " << builderOpState << ".addOperands(operands);\n"; 1603 1604 // Attributes 1605 body << " " << builderOpState << ".addAttributes(attributes);\n"; 1606 1607 // Create the correct number of regions 1608 if (int numRegions = op.getNumRegions()) { 1609 body << llvm::formatv( 1610 " for (unsigned i = 0; i != {0}; ++i)\n", 1611 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); 1612 body << " (void)" << builderOpState << ".addRegion();\n"; 1613 } 1614 1615 // Result types 1616 if (numVariadicResults == 0 || numNonVariadicResults != 0) 1617 body << " assert(resultTypes.size()" 1618 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults 1619 << "u && \"mismatched number of return types\");\n"; 1620 body << " " << builderOpState << ".addTypes(resultTypes);\n"; 1621 1622 // Generate builder that infers type too. 1623 // TODO: Expand to handle successors. 
1624 if (canInferType(op) && op.getNumSuccessors() == 0) 1625 genInferredTypeCollectiveParamBuilder(); 1626 } 1627 1628 void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList, 1629 llvm::StringSet<> &inferredAttributes, 1630 SmallVectorImpl<std::string> &resultTypeNames, 1631 TypeParamKind typeParamKind, 1632 AttrParamKind attrParamKind) { 1633 resultTypeNames.clear(); 1634 auto numResults = op.getNumResults(); 1635 resultTypeNames.reserve(numResults); 1636 1637 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 1638 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1639 1640 switch (typeParamKind) { 1641 case TypeParamKind::None: 1642 break; 1643 case TypeParamKind::Separate: { 1644 // Add parameters for all return types 1645 for (int i = 0; i < numResults; ++i) { 1646 const auto &result = op.getResult(i); 1647 std::string resultName = std::string(result.name); 1648 if (resultName.empty()) 1649 resultName = std::string(formatv("resultType{0}", i)); 1650 1651 StringRef type = 1652 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; 1653 1654 paramList.emplace_back(type, resultName, result.isOptional()); 1655 resultTypeNames.emplace_back(std::move(resultName)); 1656 } 1657 } break; 1658 case TypeParamKind::Collective: { 1659 paramList.emplace_back("::mlir::TypeRange", "resultTypes"); 1660 resultTypeNames.push_back("resultTypes"); 1661 } break; 1662 } 1663 1664 // Add parameters for all arguments (operands and attributes). 1665 int defaultValuedAttrStartIndex = op.getNumArgs(); 1666 // Successors and variadic regions go at the end of the parameter list, so no 1667 // default arguments are possible. 1668 bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions(); 1669 if (attrParamKind == AttrParamKind::UnwrappedValue && !hasTrailingParams) { 1670 // Calculate the start index from which we can attach default values in the 1671 // builder declaration. 
1672 for (int i = op.getNumArgs() - 1; i >= 0; --i) { 1673 auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>(); 1674 if (!namedAttr || !namedAttr->attr.hasDefaultValue()) 1675 break; 1676 1677 if (!canUseUnwrappedRawValue(namedAttr->attr)) 1678 break; 1679 1680 // Creating an APInt requires us to provide bitwidth, value, and 1681 // signedness, which is complicated compared to others. Similarly 1682 // for APFloat. 1683 // TODO: Adjust the 'returnType' field of such attributes 1684 // to support them. 1685 StringRef retType = namedAttr->attr.getReturnType(); 1686 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") 1687 break; 1688 1689 defaultValuedAttrStartIndex = i; 1690 } 1691 } 1692 1693 /// Collect any inferred attributes. 1694 for (const NamedTypeConstraint &operand : op.getOperands()) { 1695 if (operand.isVariadicOfVariadic()) { 1696 inferredAttributes.insert( 1697 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); 1698 } 1699 } 1700 1701 for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { 1702 Argument arg = op.getArg(i); 1703 if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) { 1704 StringRef type; 1705 if (operand->isVariadicOfVariadic()) 1706 type = "::llvm::ArrayRef<::mlir::ValueRange>"; 1707 else if (operand->isVariadic()) 1708 type = "::mlir::ValueRange"; 1709 else 1710 type = "::mlir::Value"; 1711 1712 paramList.emplace_back(type, getArgumentName(op, numOperands++), 1713 operand->isOptional()); 1714 continue; 1715 } 1716 const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>(); 1717 const Attribute &attr = namedAttr.attr; 1718 1719 // inferred attributes don't need to be added to the param list. 
1720 if (inferredAttributes.contains(namedAttr.name)) 1721 continue; 1722 1723 StringRef type; 1724 switch (attrParamKind) { 1725 case AttrParamKind::WrappedAttr: 1726 type = attr.getStorageType(); 1727 break; 1728 case AttrParamKind::UnwrappedValue: 1729 if (canUseUnwrappedRawValue(attr)) 1730 type = attr.getReturnType(); 1731 else 1732 type = attr.getStorageType(); 1733 break; 1734 } 1735 1736 // Attach default value if requested and possible. 1737 std::string defaultValue; 1738 if (attrParamKind == AttrParamKind::UnwrappedValue && 1739 i >= defaultValuedAttrStartIndex) { 1740 defaultValue += attr.getDefaultValue(); 1741 } 1742 paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue), 1743 attr.isOptional()); 1744 } 1745 1746 /// Insert parameters for each successor. 1747 for (const NamedSuccessor &succ : op.getSuccessors()) { 1748 StringRef type = 1749 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *"; 1750 paramList.emplace_back(type, succ.name); 1751 } 1752 1753 /// Insert parameters for variadic regions. 1754 for (const NamedRegion ®ion : op.getRegions()) 1755 if (region.isVariadic()) 1756 paramList.emplace_back("unsigned", 1757 llvm::formatv("{0}Count", region.name).str()); 1758 } 1759 1760 void OpEmitter::genCodeForAddingArgAndRegionForBuilder( 1761 MethodBody &body, llvm::StringSet<> &inferredAttributes, 1762 bool isRawValueAttr) { 1763 // Push all operands to the result. 1764 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 1765 std::string argName = getArgumentName(op, i); 1766 NamedTypeConstraint &operand = op.getOperand(i); 1767 if (operand.constraint.isVariadicOfVariadic()) { 1768 body << " for (::mlir::ValueRange range : " << argName << ")\n " 1769 << builderOpState << ".addOperands(range);\n"; 1770 1771 // Add the segment attribute. 
1772 body << " {\n" 1773 << " ::llvm::SmallVector<int32_t> rangeSegments;\n" 1774 << " for (::mlir::ValueRange range : " << argName << ")\n" 1775 << " rangeSegments.push_back(range.size());\n" 1776 << " " << builderOpState << ".addAttribute(" 1777 << op.getGetterName( 1778 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) 1779 << "AttrName(" << builderOpState << ".name), " << odsBuilder 1780 << ".getI32TensorAttr(rangeSegments));" 1781 << " }\n"; 1782 continue; 1783 } 1784 1785 if (operand.isOptional()) 1786 body << " if (" << argName << ")\n "; 1787 body << " " << builderOpState << ".addOperands(" << argName << ");\n"; 1788 } 1789 1790 // If the operation has the operand segment size attribute, add it here. 1791 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 1792 std::string sizes = op.getGetterName("operand_segment_sizes"); 1793 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" 1794 << builderOpState << ".name), " 1795 << "odsBuilder.getI32VectorAttr({"; 1796 interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) { 1797 const NamedTypeConstraint &operand = op.getOperand(i); 1798 if (!operand.isVariableLength()) { 1799 body << "1"; 1800 return; 1801 } 1802 1803 std::string operandName = getArgumentName(op, i); 1804 if (operand.isOptional()) { 1805 body << "(" << operandName << " ? 1 : 0)"; 1806 } else if (operand.isVariadicOfVariadic()) { 1807 body << llvm::formatv( 1808 "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, " 1809 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + " 1810 "range.size(); }))", 1811 operandName); 1812 } else { 1813 body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())"; 1814 } 1815 }); 1816 body << "}));\n"; 1817 } 1818 1819 // Push all attributes to the result. 
1820 for (const auto &namedAttr : op.getAttributes()) { 1821 auto &attr = namedAttr.attr; 1822 if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) 1823 continue; 1824 1825 bool emitNotNullCheck = attr.isOptional(); 1826 if (emitNotNullCheck) 1827 body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; 1828 1829 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { 1830 // If this is a raw value, then we need to wrap it in an Attribute 1831 // instance. 1832 FmtContext fctx; 1833 fctx.withBuilder("odsBuilder"); 1834 1835 std::string builderTemplate = std::string(attr.getConstBuilderTemplate()); 1836 1837 // For StringAttr, its constant builder call will wrap the input in 1838 // quotes, which is correct for normal string literals, but incorrect 1839 // here given we use function arguments. So we need to strip the 1840 // wrapping quotes. 1841 if (StringRef(builderTemplate).contains("\"$0\"")) 1842 builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); 1843 1844 std::string value = 1845 std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); 1846 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", 1847 builderOpState, op.getGetterName(namedAttr.name), value); 1848 } else { 1849 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", 1850 builderOpState, op.getGetterName(namedAttr.name), 1851 namedAttr.name); 1852 } 1853 if (emitNotNullCheck) 1854 body << " }\n"; 1855 } 1856 1857 // Create the correct number of regions. 1858 for (const NamedRegion ®ion : op.getRegions()) { 1859 if (region.isVariadic()) 1860 body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", 1861 region.name); 1862 1863 body << " (void)" << builderOpState << ".addRegion();\n"; 1864 } 1865 1866 // Push all successors to the result. 
1867 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { 1868 body << formatv(" {0}.addSuccessors({1});\n", builderOpState, 1869 namedSuccessor.name); 1870 } 1871 } 1872 1873 void OpEmitter::genCanonicalizerDecls() { 1874 bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod"); 1875 if (hasCanonicalizeMethod) { 1876 // static LogicResult FooOp:: 1877 // canonicalize(FooOp op, PatternRewriter &rewriter); 1878 SmallVector<MethodParameter> paramList; 1879 paramList.emplace_back(op.getCppClassName(), "op"); 1880 paramList.emplace_back("::mlir::PatternRewriter &", "rewriter"); 1881 auto *m = opClass.declareStaticMethod("::mlir::LogicalResult", 1882 "canonicalize", std::move(paramList)); 1883 ERROR_IF_PRUNED(m, "canonicalize", op); 1884 } 1885 1886 // We get a prototype for 'getCanonicalizationPatterns' if requested directly 1887 // or if using a 'canonicalize' method. 1888 bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer"); 1889 if (!hasCanonicalizeMethod && !hasCanonicalizer) 1890 return; 1891 1892 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize' 1893 // method, but not implementing 'getCanonicalizationPatterns' manually. 1894 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer; 1895 1896 // Add a signature for getCanonicalizationPatterns if implemented by the 1897 // dialect or if synthesized to call 'canonicalize'. 1898 SmallVector<MethodParameter> paramList; 1899 paramList.emplace_back("::mlir::RewritePatternSet &", "results"); 1900 paramList.emplace_back("::mlir::MLIRContext *", "context"); 1901 auto kind = hasBody ? Method::Static : Method::StaticDeclaration; 1902 auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind, 1903 std::move(paramList)); 1904 1905 // If synthesizing the method, fill it it. 
1906 if (hasBody) { 1907 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op); 1908 method->body() << " results.add(canonicalize);\n"; 1909 } 1910 } 1911 1912 void OpEmitter::genFolderDecls() { 1913 bool hasSingleResult = 1914 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0; 1915 1916 if (def.getValueAsBit("hasFolder")) { 1917 if (hasSingleResult) { 1918 auto *m = opClass.declareMethod( 1919 "::mlir::OpFoldResult", "fold", 1920 MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands")); 1921 ERROR_IF_PRUNED(m, "operands", op); 1922 } else { 1923 SmallVector<MethodParameter> paramList; 1924 paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); 1925 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", 1926 "results"); 1927 auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold", 1928 std::move(paramList)); 1929 ERROR_IF_PRUNED(m, "fold", op); 1930 } 1931 } 1932 } 1933 1934 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) { 1935 Interface interface = opTrait->getInterface(); 1936 1937 // Get the set of methods that should always be declared. 1938 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods(); 1939 llvm::StringSet<> alwaysDeclaredMethods; 1940 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), 1941 alwaysDeclaredMethodsVec.end()); 1942 1943 for (const InterfaceMethod &method : interface.getMethods()) { 1944 // Don't declare if the method has a body. 1945 if (method.getBody()) 1946 continue; 1947 // Don't declare if the method has a default implementation and the op 1948 // didn't request that it always be declared. 1949 if (method.getDefaultImplementation() && 1950 !alwaysDeclaredMethods.count(method.getName())) 1951 continue; 1952 // Interface methods are allowed to overlap with existing methods, so don't 1953 // check if pruned. 
(void)genOpInterfaceMethod(method);
  }
}

/// Adds to the op class a method with the signature of the interface method
/// `method`: the same return type, name and argument list, and static if the
/// interface method is static.
///
/// \param declaration when true, only a declaration is added. Otherwise a
///                    definition is added, and the caller fills in its body.
/// \returns the method created by `opClass.addMethod`. Callers in this file
///          check it with ERROR_IF_PRUNED, so it may be null when the method
///          was pruned.
Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
                                        bool declaration) {
  SmallVector<MethodParameter> paramList;
  for (const InterfaceMethod::Argument &arg : method.getArguments())
    paramList.emplace_back(arg.type, arg.name);

  auto props = (method.isStatic() ? Method::Static : Method::None) |
               (declaration ? Method::Declaration : Method::None);
  return opClass.addMethod(method.getReturnType(), method.getName(), props,
                           std::move(paramList));
}

/// Declares the required interface methods for every interface trait on the op
/// whose `shouldDeclareMethods()` is true.
void OpEmitter::genOpInterfaceMethods() {
  for (const auto &trait : op.getTraits()) {
    if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
      if (opTrait->shouldDeclareMethods())
        genOpInterfaceMethods(opTrait);
  }
}

/// Generates one `getEffects` method for each side-effect interface the op
/// uses, keyed by the effect's base name.
///
/// Effect locations are collected from three sources:
///   - side-effect traits, as static effects with no attached value;
///   - decorators on operands and on symbol-reference attributes;
///   - decorators on results.
/// A decorator-based effect also adds its interface trait to the op class.
/// Each generated method appends one EffectInstance per collected location.
void OpEmitter::genSideEffectInterfaceMethods() {
  enum EffectKind { Operand, Result, Symbol, Static };
  struct EffectLocation {
    /// The effect applied.
    SideEffect effect;

    /// The index if the kind is not static.
    unsigned index;

    /// The kind of the location.
    unsigned kind;
  };

  StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
  // Records every SideEffect decorator in `decorators` at (`index`, `kind`).
  // Each such decorator also attaches its interface trait to the op class.
  auto resolveDecorators = [&](Operator::var_decorator_range decorators,
                               unsigned index, unsigned kind) {
    for (auto decorator : decorators)
      if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
        opClass.addTrait(effect->getInterfaceTrait());
        interfaceEffects[effect->getBaseEffectName()].push_back(
            EffectLocation{*effect, index, kind});
      }
  };

  // Collect effects that were specified via:
  /// Traits.
  for (const auto &trait : op.getTraits()) {
    const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
    if (!opTrait)
      continue;
    auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
    for (auto decorator : opTrait->getEffects())
      effects.push_back(EffectLocation{cast<SideEffect>(decorator),
                                       /*index=*/0, EffectKind::Static});
  }
  /// Attributes and Operands.
  for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
    Argument arg = op.getArg(i);
    if (arg.is<NamedTypeConstraint *>()) {
      // Operands are recorded by operand index, for use with getODSOperands().
      resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
      ++operandIt;
      continue;
    }
    const NamedAttribute *attr = arg.get<NamedAttribute *>();
    // Symbols are recorded by *argument* index, because the emission loop
    // below looks them up again with op.getArg(location.index).
    if (attr->attr.getBaseAttr().isSymbolRefAttr())
      resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
  }
  /// Results.
  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
    resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);

  // The code used to add an effect instance.
  // {0}: The effect class.
  // {1}: Optional value or symbol reference (including a trailing ", ").
  // {2}: The resource class.
  const char *addEffectCode =
      " effects.emplace_back({0}::get(), {1}{2}::get());\n";

  for (auto &it : interfaceEffects) {
    // Generate the 'getEffects' method.
    std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::"
                                     "SideEffects::EffectInstance<{0}>> &",
                                     it.first())
                           .str();
    auto *getEffects = opClass.addMethod("void", "getEffects",
                                         MethodParameter(type, "effects"));
    ERROR_IF_PRUNED(getEffects, "getEffects", op);
    auto &body = getEffects->body();

    // Add effect instances for each of the locations marked on the operation.
    for (auto &location : it.second) {
      StringRef effect = location.effect.getName();
      StringRef resource = location.effect.getResource();
      if (location.kind == EffectKind::Static) {
        // A static instance has no attached value.
        body << llvm::formatv(addEffectCode, effect, "", resource).str();
      } else if (location.kind == EffectKind::Symbol) {
        // A symbol reference requires adding the proper attribute.
        const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
        std::string argName = op.getGetterName(attr->name);
        if (attr->attr.isOptional()) {
          // Only record the effect when the optional symbol is present.
          body << " if (auto symbolRef = " << argName << "Attr())\n "
               << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
                      .str();
        } else {
          body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
                                resource)
                      .str();
        }
      } else {
        // Otherwise this is an operand/result, so we need to attach the Value.
        // One instance is emitted per value in the (possibly variadic) group.
        body << " for (::mlir::Value value : getODS"
             << (location.kind == EffectKind::Operand ? "Operands" : "Results")
             << "(" << location.index << "))\n "
             << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
      }
    }
  }
}

/// Generates the `inferReturnTypes` definition for ops whose result types are
/// all known from the ODS definition (`op.allResultTypesKnown()`).
///
/// Each result type comes from one of two places. It is either built from the
/// type's builder call, or copied from the operand/attribute that the result
/// must match (`op.getSameTypeAsResult`).
void OpEmitter::genTypeInterfaceMethods() {
  if (!op.allResultTypesKnown())
    return;
  // Generate the 'inferReturnTypes' method definition, using the signature of
  // the interface method declared in the 'InferTypeOpInterface' op interface.
  const auto *trait =
      cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait"));
  Interface interface = trait->getInterface();
  Method *method = [&]() -> Method * {
    for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
      if (interfaceMethod.getName() == "inferReturnTypes") {
        return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
      }
    }
    assert(0 && "unable to find inferReturnTypes interface method");
    return nullptr;
  }();
  ERROR_IF_PRUNED(method, "inferReturnTypes", op);
  auto &body = method->body();
  body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";

  FmtContext fctx;
  fctx.withBuilder("odsBuilder");
  body << " ::mlir::Builder odsBuilder(context);\n";

  // Emits the C++ expression that yields the type described by `type`.
  auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
    if (!type.isArg())
      return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
    auto argIndex = type.getArg();
    assert(!op.getArg(argIndex).is<NamedAttribute *>());
    auto arg = op.getArgToOperandOrAttribute(argIndex);
    if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
      return body << "operands[" << arg.operandOrAttributeIndex()
                  << "].getType()";
    // NOTE(review): the assert above rejects attribute arguments, so this path
    // looks unreachable in asserts-enabled builds; confirm whether it is still
    // needed.
    return body << "attributes[" << arg.operandOrAttributeIndex()
                << "].getType()";
  };

  for (int i = 0, e = op.getNumResults(); i != e; ++i) {
    body << " inferredReturnTypes[" << i << "] = ";
    auto types = op.getSameTypeAsResult(i);
    // Only the first candidate source is used.
    emitType(types[0]) << ";\n";
    if (types.size() == 1)
      continue;
    // TODO: We could verify equality here, but skipping that for verification.
  }
  body << " return ::mlir::success();";
}

/// Generates the static `parse` method from the op's custom `parser` code.
///
/// Nothing is generated when the op has no `parser` string, or when it uses a
/// declarative `assemblyFormat`. `$cppClass` in the code is replaced by the
/// op's C++ class name.
void OpEmitter::genParser() {
  if (!hasStringAttribute(def, "parser") ||
      hasStringAttribute(def, "assemblyFormat"))
    return;

  SmallVector<MethodParameter> paramList;
  paramList.emplace_back("::mlir::OpAsmParser &", "parser");
  paramList.emplace_back("::mlir::OperationState &", "result");
  auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
                                         std::move(paramList));
  ERROR_IF_PRUNED(method, "parse", op);

  FmtContext fctx;
  fctx.addSubst("cppClass", opClass.getClassName());
  // The rtrim set omits '\n', so trailing newlines of the user code are kept.
  auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r");
  method->body() << " " << tgfmt(parser, &fctx);
}

/// Generates the `print` method from the op's custom `printer` code.
///
/// Nothing is generated when the op uses a declarative `assemblyFormat`, or
/// when `printer` is not a string initializer. `$cppClass` in the code is
/// replaced by the op's C++ class name.
void OpEmitter::genPrinter() {
  if (hasStringAttribute(def, "assemblyFormat"))
    return;

  auto *valueInit = def.getValueInit("printer");
  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
  if (!stringInit)
    return;

  auto *method = opClass.addMethod(
      "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
  ERROR_IF_PRUNED(method, "print", op);
  FmtContext fctx;
  fctx.addSubst("cppClass", opClass.getClassName());
  // The rtrim set omits '\n', so trailing newlines of the user code are kept.
  auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
  method->body() << " " << tgfmt(printer, &fctx);
}

/// Generate verification on native traits requiring attributes.
///
/// Currently this covers AttrSizedOperandSegments and AttrSizedResultSegments.
/// For each, the emitted code checks that the segment-size attribute is
/// present and has one element per ODS operand/result.
static void genNativeTraitAttrVerifier(MethodBody &body,
                                       const OpOrAdaptorHelper &emitHelper) {
  // Check that the variadic segment sizes attribute exists and contains the
  // expected number of elements.
  //
  // {0}: Attribute name.
  // {1}: Expected number of elements.
  // {2}: "operand" or "result".
  // {3}: Attribute getter call.
  // {4}: Emit error prefix.
const char *const checkAttrSizedValueSegmentsCode = R"(
  {
    auto sizeAttr = {3}.dyn_cast<::mlir::DenseIntElementsAttr>();
    if (!sizeAttr)
      return {4}"missing segment sizes attribute '{0}'");
    auto numElements =
        sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
    if (numElements != {1})
      return {4}"'{0}' attribute for specifying {2} segments must have {1} "
                "elements, but got ") << numElements;
  }
  )";

  // Verify a few traits first so that we can use getODSOperands() and
  // getODSResults() in the rest of the verifier.
  auto &op = emitHelper.getOp();
  for (auto &trait : op.getTraits()) {
    auto *t = dyn_cast<tblgen::NativeTrait>(&trait);
    if (!t)
      continue;
    std::string traitName = t->getFullyQualifiedTraitName();
    if (traitName == "::mlir::OpTrait::AttrSizedOperandSegments") {
      StringRef attrName = "operand_segment_sizes";
      body << formatv(checkAttrSizedValueSegmentsCode, attrName,
                      op.getNumOperands(), "operand",
                      emitHelper.getAttr(attrName),
                      emitHelper.emitErrorPrefix());
    } else if (traitName == "::mlir::OpTrait::AttrSizedResultSegments") {
      StringRef attrName = "result_segment_sizes";
      body << formatv(
          checkAttrSizedValueSegmentsCode, attrName, op.getNumResults(),
          "result", emitHelper.getAttr(attrName), emitHelper.emitErrorPrefix());
    }
  }
}

/// Generates the op's `verify` method.
///
/// The emitted checks run in this order:
///   1. segment-size attributes required by native traits;
///   2. attribute constraints;
///   3. operand and result constraints;
///   4. predicate traits (PredTrait);
///   5. region and successor constraints.
/// Finally, the op's custom `verifier` code is appended when it is non-empty.
/// Otherwise the method returns success.
void OpEmitter::genVerifier() {
  auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
  ERROR_IF_PRUNED(method, "verify", op);
  auto &body = method->body();

  OpOrAdaptorHelper emitHelper(op, /*isOp=*/true);
  genNativeTraitAttrVerifier(body, emitHelper);

  auto *valueInit = def.getValueInit("verifier");
  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
  bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
  populateSubstitutions(emitHelper, verifyCtx);

  genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
  genOperandResultVerifier(body, op.getOperands(), "operand");
  genOperandResultVerifier(body, op.getResults(), "result");

  for (auto &trait : op.getTraits()) {
    if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
      body << tgfmt(" if (!($0))\n "
                    "return emitOpError(\"failed to verify that $1\");\n",
                    &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
                    t->getSummary());
    }
  }

  genRegionVerifier(body);
  genSuccessorVerifier(body);

  if (hasCustomVerify) {
    FmtContext fctx;
    fctx.addSubst("cppClass", opClass.getClassName());
    // Despite its name, `printer` holds the op's custom *verifier* code. The
    // user code is responsible for returning the final result.
    auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
    body << " " << tgfmt(printer, &fctx);
  } else {
    body << " return ::mlir::success();\n";
  }
}

/// Emits verification code for the op's operand or result groups.
///
/// \param values    the ODS operand or result groups.
/// \param valueKind "operand" or "result". It selects the getODS* accessor and
///                  is used in diagnostics.
///
/// A group is checked only if it has a type predicate, is optional, or is
/// variadic-of-variadic. For such groups the emitted code checks that:
///   - an optional group holds at most one value;
///   - a variadic-of-variadic group matches its segment-size attribute;
///   - every value in the group satisfies the type constraint.
/// A running `index` counts the values checked, for diagnostics.
void OpEmitter::genOperandResultVerifier(MethodBody &body,
                                         Operator::const_value_range values,
                                         StringRef valueKind) {
  // Check that an optional value is at most 1 element.
  //
  // {0}: Value index.
  // {1}: "operand" or "result"
  const char *const verifyOptional = R"(
    if (valueGroup{0}.size() > 1) {
      return emitOpError("{1} group starting at #") << index
          << " requires 0 or 1 element, but found " << valueGroup{0}.size();
    }
)";
  // Check the types of a range of values.
  //
  // {0}: Value index.
  // {1}: Type constraint function.
  // {2}: "operand" or "result"
  const char *const verifyValues = R"(
    for (auto v : valueGroup{0}) {
      if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
        return ::mlir::failure();
    }
)";

  const auto canSkip = [](const NamedTypeConstraint &value) {
    return !value.hasPredicate() && !value.isOptional() &&
           !value.isVariadicOfVariadic();
  };
  if (values.empty() || llvm::all_of(values, canSkip))
    return;

  // NOTE(review): `fctx` is not used anywhere in this function.
  FmtContext fctx;

  body << " {\n unsigned index = 0; (void)index;\n";

  for (const auto &staticValue : llvm::enumerate(values)) {
    const NamedTypeConstraint &value = staticValue.value();

    bool hasPredicate = value.hasPredicate();
    bool isOptional = value.isOptional();
    bool isVariadicOfVariadic = value.isVariadicOfVariadic();
    if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
      continue;
    body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
                    // Capitalize the first letter to match the function name
                    valueKind.substr(0, 1).upper(), valueKind.substr(1),
                    staticValue.index());

    // If the constraint is optional check that the value group has at most 1
    // value.
    if (isOptional) {
      body << formatv(verifyOptional, staticValue.index(), valueKind);
    } else if (isVariadicOfVariadic) {
      body << formatv(
          " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
          "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
          " return ::mlir::failure();\n",
          value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
          staticValue.index());
    }

    // Otherwise, if there is no predicate there is nothing left to do.
    if (!hasPredicate)
      continue;
    // Emit a loop to check all the dynamic values in the pack.
2315 StringRef constraintFn = 2316 staticVerifierEmitter.getTypeConstraintFn(value.constraint); 2317 body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind); 2318 } 2319 2320 body << " }\n"; 2321 } 2322 2323 void OpEmitter::genRegionVerifier(MethodBody &body) { 2324 /// Code to verify a region. 2325 /// 2326 /// {0}: Getter for the regions. 2327 /// {1}: The region constraint. 2328 /// {2}: The region's name. 2329 /// {3}: The region description. 2330 const char *const verifyRegion = R"( 2331 for (auto ®ion : {0}) 2332 if (::mlir::failed({1}(*this, region, "{2}", index++))) 2333 return ::mlir::failure(); 2334 )"; 2335 /// Get a single region. 2336 /// 2337 /// {0}: The region's index. 2338 const char *const getSingleRegion = 2339 "::llvm::makeMutableArrayRef((*this)->getRegion({0}))"; 2340 2341 // If we have no regions, there is nothing more to do. 2342 const auto canSkip = [](const NamedRegion ®ion) { 2343 return region.constraint.getPredicate().isNull(); 2344 }; 2345 auto regions = op.getRegions(); 2346 if (regions.empty() && llvm::all_of(regions, canSkip)) 2347 return; 2348 2349 body << " {\n unsigned index = 0; (void)index;\n"; 2350 for (const auto &it : llvm::enumerate(regions)) { 2351 const auto ®ion = it.value(); 2352 if (canSkip(region)) 2353 continue; 2354 2355 auto getRegion = region.isVariadic() 2356 ? formatv("{0}()", op.getGetterName(region.name)).str() 2357 : formatv(getSingleRegion, it.index()).str(); 2358 auto constraintFn = 2359 staticVerifierEmitter.getRegionConstraintFn(region.constraint); 2360 body << formatv(verifyRegion, getRegion, constraintFn, region.name); 2361 } 2362 body << " }\n"; 2363 } 2364 2365 void OpEmitter::genSuccessorVerifier(MethodBody &body) { 2366 const char *const verifySuccessor = R"( 2367 for (auto *successor : {0}) 2368 if (::mlir::failed({1}(*this, successor, "{2}", index++))) 2369 return ::mlir::failure(); 2370 )"; 2371 /// Get a single successor. 2372 /// 2373 /// {0}: The successor's name. 
2374 const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())"; 2375 2376 // If we have no successors, there is nothing more to do. 2377 const auto canSkip = [](const NamedSuccessor &successor) { 2378 return successor.constraint.getPredicate().isNull(); 2379 }; 2380 auto successors = op.getSuccessors(); 2381 if (successors.empty() && llvm::all_of(successors, canSkip)) 2382 return; 2383 2384 body << " {\n unsigned index = 0; (void)index;\n"; 2385 2386 for (auto &it : llvm::enumerate(successors)) { 2387 const auto &successor = it.value(); 2388 if (canSkip(successor)) 2389 continue; 2390 2391 auto getSuccessor = 2392 formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor, 2393 successor.name, it.index()) 2394 .str(); 2395 auto constraintFn = 2396 staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint); 2397 body << formatv(verifySuccessor, getSuccessor, constraintFn, 2398 successor.name); 2399 } 2400 body << " }\n"; 2401 } 2402 2403 /// Add a size count trait to the given operation class. 2404 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind, 2405 int numTotal, int numVariadic) { 2406 if (numVariadic != 0) { 2407 if (numTotal == numVariadic) 2408 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s"); 2409 else 2410 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" + 2411 Twine(numTotal - numVariadic) + ">::Impl"); 2412 return; 2413 } 2414 switch (numTotal) { 2415 case 0: 2416 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind); 2417 break; 2418 case 1: 2419 opClass.addTrait("::mlir::OpTrait::One" + traitKind); 2420 break; 2421 default: 2422 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) + 2423 ">::Impl"); 2424 break; 2425 } 2426 } 2427 2428 void OpEmitter::genTraits() { 2429 // Add region size trait. 
unsigned numRegions = op.getNumRegions();
  unsigned numVariadicRegions = op.getNumVariadicRegions();
  addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);

  // Add result size traits.
  int numResults = op.getNumResults();
  int numVariadicResults = op.getNumVariableLengthResults();
  addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);

  // For single result ops with a known specific type, generate a OneTypedResult
  // trait.
  if (numResults == 1 && numVariadicResults == 0) {
    auto cppName = op.getResults().begin()->constraint.getCPPClassName();
    opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
  }

  // Add successor size trait.
  unsigned numSuccessors = op.getNumSuccessors();
  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
  addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);

  // Add variadic size trait and normal op traits.
  int numOperands = op.getNumOperands();
  int numVariadicOperands = op.getNumVariableLengthOperands();

  // Add operand size trait. This is spelled out instead of using
  // addSizeCountTrait because the zero-count trait name here is plural
  // ("ZeroOperands"), while addSizeCountTrait emits "Zero" + kind.
  if (numVariadicOperands != 0) {
    if (numOperands == numVariadicOperands)
      opClass.addTrait("::mlir::OpTrait::VariadicOperands");
    else
      opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" +
                       Twine(numOperands - numVariadicOperands) + ">::Impl");
  } else {
    switch (numOperands) {
    case 0:
      opClass.addTrait("::mlir::OpTrait::ZeroOperands");
      break;
    case 1:
      opClass.addTrait("::mlir::OpTrait::OneOperand");
      break;
    default:
      opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) +
                       ">::Impl");
      break;
    }
  }

  // Add the native and interface traits.
  for (const auto &trait : op.getTraits()) {
    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
    else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
  }
}

/// Generates the constexpr static `getOperationName` method, which returns
/// the op's full operation name as a StringLiteral.
void OpEmitter::genOpNameGetter() {
  auto *method = opClass.addStaticMethod<Method::Constexpr>(
      "::llvm::StringLiteral", "getOperationName");
  ERROR_IF_PRUNED(method, "getOperationName", op);
  method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
                 << "\");";
}

/// Generates an OpAsmOpInterface `getAsmResultNames` implementation for
/// multi-result ops. It assigns each result group's first value the ODS result
/// name as its suggested SSA name.
void OpEmitter::genOpAsmInterface() {
  // If the op has at most one result, or already has the Asm trait, don't
  // generate it. We only handle multi-result operations, because the name of
  // a single result is usually not interesting (generally
  // 'result'/'output'/etc.).
  // TODO: We could also add a flag to allow operations to opt in to this
  // generation, even if they only have a single result.
  int numResults = op.getNumResults();
  if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
    return;

  SmallVector<StringRef, 4> resultNames(numResults);
  for (int i = 0; i != numResults; ++i)
    resultNames[i] = op.getResultName(i);

  // Don't add the trait if none of the results have a valid name.
  if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
    return;
  opClass.addTrait("::mlir::OpAsmOpInterface::Trait");

  // Generate the right accessor for the number of results.
  auto *method = opClass.addMethod(
      "void", "getAsmResultNames",
      MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
  ERROR_IF_PRUNED(method, "getAsmResultNames", op);
  auto &body = method->body();
  // Empty groups (e.g. an absent optional or empty variadic result) are
  // skipped at runtime by the emitted llvm::empty check.
  for (int i = 0; i != numResults; ++i) {
    body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
         << " if (!llvm::empty(resultGroup" << i << "))\n"
         << " setNameFn(*resultGroup" << i << ".begin(), \""
         << resultNames[i] << "\");\n";
  }
}

//===----------------------------------------------------------------------===//
// OpOperandAdaptor emitter
//===----------------------------------------------------------------------===//

namespace {
// Helper class to emit Op operand adaptors to an output stream. Operand
// adaptors are wrappers around ArrayRef<Value> that provide named operand
// getters identical to those defined in the Op.
class OpOperandAdaptorEmitter {
public:
  static void emitDecl(const Operator &op,
                       StaticVerifierFunctionEmitter &staticVerifierEmitter,
                       raw_ostream &os);
  static void emitDef(const Operator &op,
                      StaticVerifierFunctionEmitter &staticVerifierEmitter,
                      raw_ostream &os);

private:
  explicit OpOperandAdaptorEmitter(
      const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
2551 void addVerification(); 2552 2553 const Operator &op; 2554 Class adaptor; 2555 StaticVerifierFunctionEmitter &staticVerifierEmitter; 2556 }; 2557 } // namespace 2558 2559 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( 2560 const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter) 2561 : op(op), adaptor(op.getAdaptorName()), 2562 staticVerifierEmitter(staticVerifierEmitter) { 2563 adaptor.addField("::mlir::ValueRange", "odsOperands"); 2564 adaptor.addField("::mlir::DictionaryAttr", "odsAttrs"); 2565 adaptor.addField("::mlir::RegionRange", "odsRegions"); 2566 const auto *attrSizedOperands = 2567 op.getTrait("::m::OpTrait::AttrSizedOperandSegments"); 2568 { 2569 SmallVector<MethodParameter> paramList; 2570 paramList.emplace_back("::mlir::ValueRange", "values"); 2571 paramList.emplace_back("::mlir::DictionaryAttr", "attrs", 2572 attrSizedOperands ? "" : "nullptr"); 2573 paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); 2574 auto *constructor = adaptor.addConstructor(std::move(paramList)); 2575 2576 constructor->addMemberInitializer("odsOperands", "values"); 2577 constructor->addMemberInitializer("odsAttrs", "attrs"); 2578 constructor->addMemberInitializer("odsRegions", "regions"); 2579 } 2580 2581 { 2582 auto *constructor = adaptor.addConstructor( 2583 MethodParameter(op.getCppClassName() + " &", "op")); 2584 constructor->addMemberInitializer("odsOperands", "op->getOperands()"); 2585 constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); 2586 constructor->addMemberInitializer("odsRegions", "op->getRegions()"); 2587 } 2588 2589 { 2590 auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands"); 2591 ERROR_IF_PRUNED(m, "getOperands", op); 2592 m->body() << " return odsOperands;"; 2593 } 2594 std::string attr = "operand_segment_sizes"; 2595 std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, attr); 2596 generateNamedOperandGetters(op, adaptor, 2597 /*isAdaptor=*/true, sizeAttrInit, 2598 
/*rangeType=*/"::mlir::ValueRange", 2599 /*rangeBeginCall=*/"odsOperands.begin()", 2600 /*rangeSizeCall=*/"odsOperands.size()", 2601 /*getOperandCallPattern=*/"odsOperands[{0}]"); 2602 2603 FmtContext fctx; 2604 fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); 2605 2606 // Generate named accessor with Attribute return type. 2607 auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName, 2608 Attribute attr) { 2609 auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr"); 2610 ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op); 2611 auto &body = method->body(); 2612 body << " assert(odsAttrs && \"no attributes when constructing adapter\");" 2613 << "\n " << attr.getStorageType() << " attr = " 2614 << "odsAttrs.get(\"" << name << "\")."; 2615 if (attr.hasDefaultValue() || attr.isOptional()) 2616 body << "dyn_cast_or_null<"; 2617 else 2618 body << "cast<"; 2619 body << attr.getStorageType() << ">();\n"; 2620 2621 if (attr.hasDefaultValue()) { 2622 // Use the default value if attribute is not set. 2623 // TODO: this is inefficient, we are recreating the attribute for every 2624 // call. This should be set instead. 
2625 std::string defaultValue = std::string( 2626 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); 2627 body << " if (!attr)\n attr = " << defaultValue << ";\n"; 2628 } 2629 body << " return attr;\n"; 2630 }; 2631 2632 { 2633 auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes"); 2634 ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); 2635 m->body() << " return odsAttrs;"; 2636 } 2637 for (auto &namedAttr : op.getAttributes()) { 2638 const auto &name = namedAttr.name; 2639 const auto &attr = namedAttr.attr; 2640 if (attr.isDerivedAttr()) 2641 continue; 2642 for (const auto &emitName : op.getGetterNames(name)) { 2643 emitAttrWithStorageType(name, emitName, attr); 2644 emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr); 2645 } 2646 } 2647 2648 unsigned numRegions = op.getNumRegions(); 2649 if (numRegions > 0) { 2650 auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions"); 2651 ERROR_IF_PRUNED(m, "Adaptor::getRegions", op); 2652 m->body() << " return odsRegions;"; 2653 } 2654 for (unsigned i = 0; i < numRegions; ++i) { 2655 const auto ®ion = op.getRegion(i); 2656 if (region.name.empty()) 2657 continue; 2658 2659 // Generate the accessors for a variadic region. 2660 for (StringRef name : op.getGetterNames(region.name)) { 2661 if (region.isVariadic()) { 2662 auto *m = adaptor.addMethod("::mlir::RegionRange", name); 2663 ERROR_IF_PRUNED(m, "Adaptor::" + name, op); 2664 m->body() << formatv(" return odsRegions.drop_front({0});", i); 2665 continue; 2666 } 2667 2668 auto *m = adaptor.addMethod("::mlir::Region &", name); 2669 ERROR_IF_PRUNED(m, "Adaptor::" + name, op); 2670 m->body() << formatv(" return *odsRegions[{0}];", i); 2671 } 2672 } 2673 2674 // Add verification function. 
2675 addVerification(); 2676 adaptor.finalize(); 2677 } 2678 2679 void OpOperandAdaptorEmitter::addVerification() { 2680 auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify", 2681 MethodParameter("::mlir::Location", "loc")); 2682 ERROR_IF_PRUNED(method, "verify", op); 2683 auto &body = method->body(); 2684 2685 OpOrAdaptorHelper emitHelper(op, /*isOp=*/false); 2686 genNativeTraitAttrVerifier(body, emitHelper); 2687 FmtContext verifyCtx; 2688 populateSubstitutions(emitHelper, verifyCtx); 2689 2690 genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter); 2691 2692 body << " return ::mlir::success();"; 2693 } 2694 2695 void OpOperandAdaptorEmitter::emitDecl( 2696 const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, 2697 raw_ostream &os) { 2698 OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os); 2699 } 2700 2701 void OpOperandAdaptorEmitter::emitDef( 2702 const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, 2703 raw_ostream &os) { 2704 OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os); 2705 } 2706 2707 // Emits the opcode enum and op classes. 2708 static void emitOpClasses(const RecordKeeper &recordKeeper, 2709 const std::vector<Record *> &defs, raw_ostream &os, 2710 bool emitDecl) { 2711 // First emit forward declaration for each class, this allows them to refer 2712 // to each others in traits for example. 2713 if (emitDecl) { 2714 os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n"; 2715 os << "#undef GET_OP_FWD_DEFINES\n"; 2716 for (auto *def : defs) { 2717 Operator op(*def); 2718 NamespaceEmitter emitter(os, op.getCppNamespace()); 2719 os << "class " << op.getCppClassName() << ";\n"; 2720 } 2721 os << "#endif\n\n"; 2722 } 2723 2724 IfDefScope scope("GET_OP_CLASSES", os); 2725 if (defs.empty()) 2726 return; 2727 2728 // Generate all of the locally instantiated methods first. 
2729 StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper); 2730 os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); 2731 staticVerifierEmitter.emitOpConstraints(defs, emitDecl); 2732 2733 for (auto *def : defs) { 2734 Operator op(*def); 2735 if (emitDecl) { 2736 { 2737 NamespaceEmitter emitter(os, op.getCppNamespace()); 2738 os << formatv(opCommentHeader, op.getQualCppClassName(), 2739 "declarations"); 2740 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os); 2741 OpEmitter::emitDecl(op, os, staticVerifierEmitter); 2742 } 2743 // Emit the TypeID explicit specialization to have a single definition. 2744 if (!op.getCppNamespace().empty()) 2745 os << "DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() 2746 << "::" << op.getCppClassName() << ")\n\n"; 2747 } else { 2748 { 2749 NamespaceEmitter emitter(os, op.getCppNamespace()); 2750 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); 2751 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os); 2752 OpEmitter::emitDef(op, os, staticVerifierEmitter); 2753 } 2754 // Emit the TypeID explicit specialization to have a single definition. 2755 if (!op.getCppNamespace().empty()) 2756 os << "DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() 2757 << "::" << op.getCppClassName() << ")\n\n"; 2758 } 2759 } 2760 } 2761 2762 // Emits a comma-separated list of the ops. 2763 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) { 2764 IfDefScope scope("GET_OP_LIST", os); 2765 2766 interleave( 2767 // TODO: We are constructing the Operator wrapper instance just for 2768 // getting it's qualified class name here. Reduce the overhead by having a 2769 // lightweight version of Operator class just for that purpose. 
2770 defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, 2771 [&os]() { os << ",\n"; }); 2772 } 2773 2774 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { 2775 emitSourceFileHeader("Op Declarations", os); 2776 2777 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper); 2778 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true); 2779 2780 return false; 2781 } 2782 2783 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { 2784 emitSourceFileHeader("Op Definitions", os); 2785 2786 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper); 2787 emitOpList(defs, os); 2788 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false); 2789 2790 return false; 2791 } 2792 2793 static mlir::GenRegistration 2794 genOpDecls("gen-op-decls", "Generate op declarations", 2795 [](const RecordKeeper &records, raw_ostream &os) { 2796 return emitOpDecls(records, os); 2797 }); 2798 2799 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", 2800 [](const RecordKeeper &records, 2801 raw_ostream &os) { 2802 return emitOpDefs(records, os); 2803 }); 2804