//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//

#include "OpClass.h"
#include "OpFormatGen.h"
#include "OpGenHelpers.h"
#include "mlir/TableGen/Class.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/SideEffects.h"
#include "mlir/TableGen/Trait.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"

#define DEBUG_TYPE "mlir-tblgen-opdefgen"

using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;

/// Prefix for local variables in generated verifier code. Prefixing keeps a
/// local named after an attribute from hiding that attribute's accessor.
static const char *const tblgenNamePrefix = "tblgen_";
/// Base name for operands that have no declared name; the operand index is
/// appended as "<generatedArgName>_<index>".
static const char *const generatedArgName = "odsArg";
/// NOTE(review): presumably the parameter names used for the OpBuilder and the
/// OperationState in generated build() methods; their uses are outside this
/// part of the file -- confirm.
static const char *const odsBuilder = "odsBuilder";
static const char *const builderOpState = "odsState";

/// The names of the implicit attributes that contain variadic operand and
/// result segment sizes.
static const char *const operandSegmentAttrName = "operand_segment_sizes";
static const char *const resultSegmentAttrName = "result_segment_sizes";

/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
/// lookup: only the slice [begin + {1}, end - {2}) of the sorted attribute
/// list is searched.
///
/// {0}: Code snippet to get the attribute's name or identifier.
/// {1}: The lower bound on the sorted subrange.
/// {2}: The upper bound on the sorted subrange.
/// {3}: Code snippet to get the array of named attributes.
/// {4}: "Named" to get the named attribute.
static const char *const subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
    "{2}, {0})";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values. The generated code expects the static `index` of the
/// queried operand/result to be in scope.
///
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
  bool isVariadic[] = {{{0}};
  int prevVariadicCount = 0;
  for (unsigned i = 0; i < index; ++i)
    if (isVariadic[i]) ++prevVariadicCount;

  // Calculate how many dynamic values a static variadic {4} corresponds to.
  // This assumes all static variadic {4}s have the same dynamic value count.
  int variadicSize = ({3} - {1}) / {2};
  // `index` passed in as the parameter is the static index which counts each
  // {4} (variadic or not) as size 1. So here for each previous static variadic
  // {4}, we need to offset by (variadicSize - 1) to get where the dynamic
  // value pack for this static {4} starts.
  int start = index + (variadicSize - 1) * prevVariadicCount;
  int size = isVariadic[index] ? variadicSize : 1;
  return {{start, size};
)";

/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not). The generated code expects `sizeAttr` (initialized by one
/// of the *SegmentSizeAttrInitCode snippets below) and `index` to be in scope.
/// A splat `sizeAttr` means every segment has the same size.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
  const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
  if (sizeAttr.isSplat())
    return {*sizeAttrValueIt * index, *sizeAttrValueIt};

  unsigned start = 0;
  for (unsigned i = 0; i < index; ++i)
    start += sizeAttrValueIt[i];
  return {start, sizeAttrValueIt[index]};
)";
/// The code snippet to initialize the sizes for the value range calculation,
/// adaptor variant: it additionally asserts that `odsAttrs` is non-null.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
  assert(odsAttrs && "missing segment size attribute for op");
  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
)";
/// The code snippet to initialize the sizes for the value range calculation,
/// op variant.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
)";

/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
///
/// {0}: The name of the segment attribute (the snippet calls it as `{0}()`, so
///      it must name an accessor returning the attribute).
/// {1}: The index of the main operand.
// Splits the operands of group {1} into consecutive sub-ranges whose lengths
// come, in order, from the values of the segment attribute.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
  auto tblgenTmpOperands = getODSOperands({1});
  auto sizeAttrValues = {0}().getValues<uint32_t>();
  auto sizeAttrIt = sizeAttrValues.begin();

  ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
  for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{
    tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt));
    tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt);
  }
  return tblgenTmpOperandGroups;
)";

/// The logic to build a range of either operand or result values. The value of
/// {1} is used as a (start, length) pair through `.first` / `.second`.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";

/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";

//===----------------------------------------------------------------------===//
// Utility structs and functions
//===----------------------------------------------------------------------===//

// Replaces all occurrences of `match` in `str` with `substitute`.
164 static std::string replaceAllSubstrs(std::string str, const std::string &match, 165 const std::string &substitute) { 166 std::string::size_type scanLoc = 0, matchLoc = std::string::npos; 167 while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) { 168 str = str.replace(matchLoc, match.size(), substitute); 169 scanLoc = matchLoc + substitute.size(); 170 } 171 return str; 172 } 173 174 // Returns whether the record has a value of the given name that can be returned 175 // via getValueAsString. 176 static inline bool hasStringAttribute(const Record &record, 177 StringRef fieldName) { 178 auto *valueInit = record.getValueInit(fieldName); 179 return isa<StringInit>(valueInit); 180 } 181 182 static std::string getArgumentName(const Operator &op, int index) { 183 const auto &operand = op.getOperand(index); 184 if (!operand.name.empty()) 185 return std::string(operand.name); 186 return std::string(formatv("{0}_{1}", generatedArgName, index)); 187 } 188 189 // Returns true if we can use unwrapped value for the given `attr` in builders. 190 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { 191 return attr.getReturnType() != attr.getStorageType() && 192 // We need to wrap the raw value into an attribute in the builder impl 193 // so we need to make sure that the attribute specifies how to do that. 194 !attr.getConstBuilderTemplate().empty(); 195 } 196 197 namespace { 198 /// Metadata on a registered attribute. Given that attributes are stored in 199 /// sorted order on operations, we can use information from ODS to deduce the 200 /// number of required attributes less and and greater than each attribute, 201 /// allowing us to search only a subrange of the attributes in ODS-generated 202 /// getters. 203 struct AttributeMetadata { 204 /// The attribute name. 205 StringRef attrName; 206 /// Whether the attribute is required. 207 bool isRequired; 208 /// The ODS attribute constraint. Not present for implicit attributes. 
209 Optional<Attribute> constraint; 210 /// The number of required attributes less than this attribute. 211 unsigned lowerBound = 0; 212 /// The number of required attributes greater than this attribute. 213 unsigned upperBound = 0; 214 }; 215 216 /// Helper class to select between OpAdaptor and Op code templates. 217 class OpOrAdaptorHelper { 218 public: 219 OpOrAdaptorHelper(const Operator &op, bool emitForOp) 220 : op(op), emitForOp(emitForOp) { 221 computeAttrMetadata(); 222 } 223 224 /// Object that wraps a functor in a stream operator for interop with 225 /// llvm::formatv. 226 class Formatter { 227 public: 228 template <typename Functor> 229 Formatter(Functor &&func) : func(std::forward<Functor>(func)) {} 230 231 std::string str() const { 232 std::string result; 233 llvm::raw_string_ostream os(result); 234 os << *this; 235 return os.str(); 236 } 237 238 private: 239 std::function<raw_ostream &(raw_ostream &)> func; 240 241 friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) { 242 return fmt.func(os); 243 } 244 }; 245 246 // Generate code for getting an attribute. 247 Formatter getAttr(StringRef attrName, bool isNamed = false) const { 248 assert(attrMetadata.count(attrName) && "expected attribute metadata"); 249 return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & { 250 const AttributeMetadata &attr = attrMetadata.find(attrName)->second; 251 return os << formatv(subrangeGetAttr, getAttrName(attrName), 252 attr.lowerBound, attr.upperBound, getAttrRange(), 253 isNamed ? "Named" : ""); 254 }; 255 } 256 257 // Generate code for getting the name of an attribute. 
258 Formatter getAttrName(StringRef attrName) const { 259 return [this, attrName](raw_ostream &os) -> raw_ostream & { 260 if (emitForOp) 261 return os << op.getGetterName(attrName) << "AttrName()"; 262 return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(), 263 op.getGetterName(attrName)); 264 }; 265 } 266 267 // Get the code snippet for getting the named attribute range. 268 StringRef getAttrRange() const { 269 return emitForOp ? "(*this)->getAttrs()" : "odsAttrs"; 270 } 271 272 // Get the prefix code for emitting an error. 273 Formatter emitErrorPrefix() const { 274 return [this](raw_ostream &os) -> raw_ostream & { 275 if (emitForOp) 276 return os << "emitOpError("; 277 return os << formatv("emitError(loc, \"'{0}' op \"", 278 op.getOperationName()); 279 }; 280 } 281 282 // Get the call to get an operand or segment of operands. 283 Formatter getOperand(unsigned index) const { 284 return [this, index](raw_ostream &os) -> raw_ostream & { 285 return os << formatv(op.getOperand(index).isVariadic() 286 ? "this->getODSOperands({0})" 287 : "(*this->getODSOperands({0}).begin())", 288 index); 289 }; 290 } 291 292 // Get the call to get a result of segment of results. 293 Formatter getResult(unsigned index) const { 294 return [this, index](raw_ostream &os) -> raw_ostream & { 295 if (!emitForOp) 296 return os << "<no results should be generated>"; 297 return os << formatv(op.getResult(index).isVariadic() 298 ? "this->getODSResults({0})" 299 : "(*this->getODSResults({0}).begin())", 300 index); 301 }; 302 } 303 304 // Return whether an op instance is available. 305 bool isEmittingForOp() const { return emitForOp; } 306 307 // Return the ODS operation wrapper. 308 const Operator &getOp() const { return op; } 309 310 // Get the attribute metadata sorted by name. 311 const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const { 312 return attrMetadata; 313 } 314 315 private: 316 // Compute the attribute metadata. 
317 void computeAttrMetadata(); 318 319 // The operation ODS wrapper. 320 const Operator &op; 321 // True if code is being generate for an op. False for an adaptor. 322 const bool emitForOp; 323 324 // The attribute metadata, mapped by name. 325 llvm::MapVector<StringRef, AttributeMetadata> attrMetadata; 326 // The number of required attributes. 327 unsigned numRequired; 328 }; 329 330 } // namespace 331 332 void OpOrAdaptorHelper::computeAttrMetadata() { 333 // Enumerate the attribute names of this op, ensuring the attribute names are 334 // unique in case implicit attributes are explicitly registered. 335 for (const NamedAttribute &namedAttr : op.getAttributes()) { 336 Attribute attr = namedAttr.attr; 337 bool isOptional = 338 attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr(); 339 attrMetadata.insert( 340 {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); 341 } 342 // Include key attributes from several traits as implicitly registered. 343 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 344 attrMetadata.insert( 345 {operandSegmentAttrName, 346 AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true, 347 /*attr=*/llvm::None}}); 348 } 349 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { 350 attrMetadata.insert( 351 {resultSegmentAttrName, 352 AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, 353 /*attr=*/llvm::None}}); 354 } 355 356 // Store the metadata in sorted order. 357 SmallVector<AttributeMetadata> sortedAttrMetadata = 358 llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector())); 359 llvm::sort(sortedAttrMetadata, 360 [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) { 361 return lhs.attrName < rhs.attrName; 362 }); 363 364 // Compute the subrange bounds for each attribute. 
365 numRequired = 0; 366 for (AttributeMetadata &attr : sortedAttrMetadata) { 367 attr.lowerBound = numRequired; 368 numRequired += attr.isRequired; 369 }; 370 for (AttributeMetadata &attr : sortedAttrMetadata) 371 attr.upperBound = numRequired - attr.lowerBound - attr.isRequired; 372 373 // Store the results back into the map. 374 for (const AttributeMetadata &attr : sortedAttrMetadata) 375 attrMetadata.insert({attr.attrName, attr}); 376 } 377 378 //===----------------------------------------------------------------------===// 379 // Op emitter 380 //===----------------------------------------------------------------------===// 381 382 namespace { 383 // Helper class to emit a record into the given output stream. 384 class OpEmitter { 385 public: 386 static void 387 emitDecl(const Operator &op, raw_ostream &os, 388 const StaticVerifierFunctionEmitter &staticVerifierEmitter); 389 static void 390 emitDef(const Operator &op, raw_ostream &os, 391 const StaticVerifierFunctionEmitter &staticVerifierEmitter); 392 393 private: 394 OpEmitter(const Operator &op, 395 const StaticVerifierFunctionEmitter &staticVerifierEmitter); 396 397 void emitDecl(raw_ostream &os); 398 void emitDef(raw_ostream &os); 399 400 // Generate methods for accessing the attribute names of this operation. 401 void genAttrNameGetters(); 402 403 // Generates the OpAsmOpInterface for this operation if possible. 404 void genOpAsmInterface(); 405 406 // Generates the `getOperationName` method for this op. 407 void genOpNameGetter(); 408 409 // Generates getters for the attributes. 410 void genAttrGetters(); 411 412 // Generates setter for the attributes. 413 void genAttrSetters(); 414 415 // Generates removers for optional attributes. 416 void genOptionalAttrRemovers(); 417 418 // Generates getters for named operands. 419 void genNamedOperandGetters(); 420 421 // Generates setters for named operands. 422 void genNamedOperandSetters(); 423 424 // Generates getters for named results. 
425 void genNamedResultGetters(); 426 427 // Generates getters for named regions. 428 void genNamedRegionGetters(); 429 430 // Generates getters for named successors. 431 void genNamedSuccessorGetters(); 432 433 // Generates builder methods for the operation. 434 void genBuilder(); 435 436 // Generates the build() method that takes each operand/attribute 437 // as a stand-alone parameter. 438 void genSeparateArgParamBuilder(); 439 440 // Generates the build() method that takes each operand/attribute as a 441 // stand-alone parameter. The generated build() method uses first operand's 442 // type as all results' types. 443 void genUseOperandAsResultTypeSeparateParamBuilder(); 444 445 // Generates the build() method that takes all operands/attributes 446 // collectively as one parameter. The generated build() method uses first 447 // operand's type as all results' types. 448 void genUseOperandAsResultTypeCollectiveParamBuilder(); 449 450 // Generates the build() method that takes aggregate operands/attributes 451 // parameters. This build() method uses inferred types as result types. 452 // Requires: The type needs to be inferable via InferTypeOpInterface. 453 void genInferredTypeCollectiveParamBuilder(); 454 455 // Generates the build() method that takes each operand/attribute as a 456 // stand-alone parameter. The generated build() method uses first attribute's 457 // type as all result's types. 458 void genUseAttrAsResultTypeBuilder(); 459 460 // Generates the build() method that takes all result types collectively as 461 // one parameter. Similarly for operands and attributes. 462 void genCollectiveParamBuilder(); 463 464 // The kind of parameter to generate for result types in builders. 465 enum class TypeParamKind { 466 None, // No result type in parameter list. 467 Separate, // A separate parameter for each result type. 468 Collective, // An ArrayRef<Type> for all result types. 469 }; 470 471 // The kind of parameter to generate for attributes in builders. 
472 enum class AttrParamKind { 473 WrappedAttr, // A wrapped MLIR Attribute instance. 474 UnwrappedValue, // A raw value without MLIR Attribute wrapper. 475 }; 476 477 // Builds the parameter list for build() method of this op. This method writes 478 // to `paramList` the comma-separated parameter list and updates 479 // `resultTypeNames` with the names for parameters for specifying result 480 // types. `inferredAttributes` is populated with any attributes that are 481 // elided from the build list. The given `typeParamKind` and `attrParamKind` 482 // controls how result types and attributes are placed in the parameter list. 483 void buildParamList(SmallVectorImpl<MethodParameter> ¶mList, 484 llvm::StringSet<> &inferredAttributes, 485 SmallVectorImpl<std::string> &resultTypeNames, 486 TypeParamKind typeParamKind, 487 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); 488 489 // Adds op arguments and regions into operation state for build() methods. 490 void 491 genCodeForAddingArgAndRegionForBuilder(MethodBody &body, 492 llvm::StringSet<> &inferredAttributes, 493 bool isRawValueAttr = false); 494 495 // Generates canonicalizer declaration for the operation. 496 void genCanonicalizerDecls(); 497 498 // Generates the folder declaration for the operation. 499 void genFolderDecls(); 500 501 // Generates the parser for the operation. 502 void genParser(); 503 504 // Generates the printer for the operation. 505 void genPrinter(); 506 507 // Generates verify method for the operation. 508 void genVerifier(); 509 510 // Generates custom verify methods for the operation. 511 void genCustomVerifier(); 512 513 // Generates verify statements for operands and results in the operation. 514 // The generated code will be attached to `body`. 515 void genOperandResultVerifier(MethodBody &body, 516 Operator::const_value_range values, 517 StringRef valueKind); 518 519 // Generates verify statements for regions in the operation. 
520 // The generated code will be attached to `body`. 521 void genRegionVerifier(MethodBody &body); 522 523 // Generates verify statements for successors in the operation. 524 // The generated code will be attached to `body`. 525 void genSuccessorVerifier(MethodBody &body); 526 527 // Generates the traits used by the object. 528 void genTraits(); 529 530 // Generate the OpInterface methods for all interfaces. 531 void genOpInterfaceMethods(); 532 533 // Generate op interface methods for the given interface. 534 void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait); 535 536 // Generate op interface method for the given interface method. If 537 // 'declaration' is true, generates a declaration, else a definition. 538 Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, 539 bool declaration = true); 540 541 // Generate the side effect interface methods. 542 void genSideEffectInterfaceMethods(); 543 544 // Generate the type inference interface methods. 545 void genTypeInterfaceMethods(); 546 547 private: 548 // The TableGen record for this op. 549 // TODO: OpEmitter should not have a Record directly, 550 // it should rather go through the Operator for better abstraction. 551 const Record &def; 552 553 // The wrapper operator class for querying information from this op. 554 const Operator &op; 555 556 // The C++ code builder for this op 557 OpClass opClass; 558 559 // The format context for verification code generation. 560 FmtContext verifyCtx; 561 562 // The emitter containing all of the locally emitted verification functions. 563 const StaticVerifierFunctionEmitter &staticVerifierEmitter; 564 565 // Helper for emitting op code. 566 OpOrAdaptorHelper emitHelper; 567 }; 568 569 } // namespace 570 571 // Populate the format context `ctx` with substitutions of attributes, operands 572 // and results. 573 static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper, 574 FmtContext &ctx) { 575 // Populate substitutions for attributes. 
576 auto &op = emitHelper.getOp(); 577 for (const auto &namedAttr : op.getAttributes()) 578 ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str()); 579 580 // Populate substitutions for named operands. 581 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 582 auto &value = op.getOperand(i); 583 if (!value.name.empty()) 584 ctx.addSubst(value.name, emitHelper.getOperand(i).str()); 585 } 586 587 // Populate substitutions for results. 588 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 589 auto &value = op.getResult(i); 590 if (!value.name.empty()) 591 ctx.addSubst(value.name, emitHelper.getResult(i).str()); 592 } 593 } 594 595 /// Generate verification on native traits requiring attributes. 596 static void genNativeTraitAttrVerifier(MethodBody &body, 597 const OpOrAdaptorHelper &emitHelper) { 598 // Check that the variadic segment sizes attribute exists and contains the 599 // expected number of elements. 600 // 601 // {0}: Attribute name. 602 // {1}: Expected number of elements. 603 // {2}: "operand" or "result". 604 // {3}: Emit error prefix. 605 const char *const checkAttrSizedValueSegmentsCode = R"( 606 { 607 auto sizeAttr = tblgen_{0}.cast<::mlir::DenseIntElementsAttr>(); 608 auto numElements = 609 sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements(); 610 if (numElements != {1}) 611 return {3}"'{0}' attribute for specifying {2} segments must have {1} " 612 "elements, but got ") << numElements; 613 } 614 )"; 615 616 // Verify a few traits first so that we can use getODSOperands() and 617 // getODSResults() in the rest of the verifier. 
618 auto &op = emitHelper.getOp(); 619 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 620 body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, 621 op.getNumOperands(), "operand", 622 emitHelper.emitErrorPrefix()); 623 } 624 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { 625 body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, 626 op.getNumResults(), "result", emitHelper.emitErrorPrefix()); 627 } 628 } 629 630 // Generate attribute verification. If an op instance is not available, then 631 // attribute checks that require one will not be emitted. 632 // 633 // Attribute verification is performed as follows: 634 // 635 // 1. Verify that all required attributes are present in sorted order. This 636 // ensures that we can use subrange lookup even with potentially missing 637 // attributes. 638 // 2. Verify native trait attributes so that other attributes may call methods 639 // that depend on the validity of these attributes, e.g. segment size attributes 640 // and operand or result getters. 641 // 3. Verify the constraints on all present attributes. 642 static void genAttributeVerifier( 643 const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body, 644 const StaticVerifierFunctionEmitter &staticVerifierEmitter) { 645 if (emitHelper.getAttrMetadata().empty()) 646 return; 647 648 // Verify the attribute if it is present. This assumes that default values 649 // are valid. This code snippet pastes the condition inline. 650 // 651 // TODO: verify the default value is valid (perhaps in debug mode only). 652 // 653 // {0}: Attribute variable name. 654 // {1}: Attribute condition code. 655 // {2}: Emit error prefix. 656 // {3}: Attribute name. 657 // {4}: Attribute/constraint description. 658 const char *const verifyAttrInline = R"( 659 if ({0} && !({1})) 660 return {2}"attribute '{3}' failed to satisfy constraint: {4}"); 661 )"; 662 // Verify the attribute using a uniqued constraint. 
Can only be used within 663 // the context of an op. 664 // 665 // {0}: Unique constraint name. 666 // {1}: Attribute variable name. 667 // {2}: Attribute name. 668 const char *const verifyAttrUnique = R"( 669 if (::mlir::failed({0}(*this, {1}, "{2}"))) 670 return ::mlir::failure(); 671 )"; 672 673 // Traverse the array until the required attribute is found. Return an error 674 // if the traversal reached the end. 675 // 676 // {0}: Code to get the name of the attribute. 677 // {1}: The emit error prefix. 678 // {2}: The name of the attribute. 679 const char *const findRequiredAttr = R"(while (true) {{ 680 if (namedAttrIt == namedAttrRange.end()) 681 return {1}"requires attribute '{2}'"); 682 if (namedAttrIt->getName() == {0}) {{ 683 tblgen_{2} = namedAttrIt->getValue(); 684 break; 685 })"; 686 687 // Emit a check to see if the iteration has encountered an optional attribute. 688 // 689 // {0}: Code to get the name of the attribute. 690 // {1}: The name of the attribute. 691 const char *const checkOptionalAttr = R"( 692 else if (namedAttrIt->getName() == {0}) {{ 693 tblgen_{1} = namedAttrIt->getValue(); 694 })"; 695 696 // Emit the start of the loop for checking trailing attributes. 697 const char *const checkTrailingAttrs = R"(while (true) { 698 if (namedAttrIt == namedAttrRange.end()) { 699 break; 700 })"; 701 702 // Return true if a verifier can be emitted for the attribute: it is not a 703 // derived attribute, it has a predicate, its condition is not empty, and, for 704 // adaptors, the condition does not reference the op. 705 const auto canEmitVerifier = [&](Attribute attr) { 706 if (attr.isDerivedAttr()) 707 return false; 708 Pred pred = attr.getPredicate(); 709 if (pred.isNull()) 710 return false; 711 std::string condition = pred.getCondition(); 712 return !condition.empty() && (!StringRef(condition).contains("$_op") || 713 emitHelper.isEmittingForOp()); 714 }; 715 716 // Emit the verifier for the attribute. 
717 const auto emitVerifier = [&](Attribute attr, StringRef attrName, 718 StringRef varName) { 719 std::string condition = attr.getPredicate().getCondition(); 720 721 Optional<StringRef> constraintFn; 722 if (emitHelper.isEmittingForOp() && 723 (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) { 724 body << formatv(verifyAttrUnique, *constraintFn, varName, attrName); 725 } else { 726 body << formatv(verifyAttrInline, varName, 727 tgfmt(condition, &ctx.withSelf(varName)), 728 emitHelper.emitErrorPrefix(), attrName, 729 escapeString(attr.getSummary())); 730 } 731 }; 732 733 // Prefix variables with `tblgen_` to avoid hiding the attribute accessor. 734 const auto getVarName = [&](StringRef attrName) { 735 return (tblgenNamePrefix + attrName).str(); 736 }; 737 738 body.indent() << formatv("auto namedAttrRange = {0};\n", 739 emitHelper.getAttrRange()); 740 body << "auto namedAttrIt = namedAttrRange.begin();\n"; 741 742 // Iterate over the attributes in sorted order. Keep track of the optional 743 // attributes that may be encountered along the way. 
744 SmallVector<const AttributeMetadata *> optionalAttrs; 745 for (const std::pair<StringRef, AttributeMetadata> &it : 746 emitHelper.getAttrMetadata()) { 747 const AttributeMetadata &metadata = it.second; 748 if (!metadata.isRequired) { 749 optionalAttrs.push_back(&metadata); 750 continue; 751 } 752 753 body << formatv("::mlir::Attribute {0};\n", getVarName(it.first)); 754 for (const AttributeMetadata *optional : optionalAttrs) { 755 body << formatv("::mlir::Attribute {0};\n", 756 getVarName(optional->attrName)); 757 } 758 body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first), 759 emitHelper.emitErrorPrefix(), it.first); 760 for (const AttributeMetadata *optional : optionalAttrs) { 761 body << formatv(checkOptionalAttr, 762 emitHelper.getAttrName(optional->attrName), 763 optional->attrName); 764 } 765 body << "\n ++namedAttrIt;\n}\n"; 766 optionalAttrs.clear(); 767 } 768 // Get trailing optional attributes. 769 if (!optionalAttrs.empty()) { 770 for (const AttributeMetadata *optional : optionalAttrs) { 771 body << formatv("::mlir::Attribute {0};\n", 772 getVarName(optional->attrName)); 773 } 774 body << checkTrailingAttrs; 775 for (const AttributeMetadata *optional : optionalAttrs) { 776 body << formatv(checkOptionalAttr, 777 emitHelper.getAttrName(optional->attrName), 778 optional->attrName); 779 } 780 body << "\n ++namedAttrIt;\n}\n"; 781 } 782 body.unindent(); 783 784 // Emit the checks for segment attributes first so that the other constraints 785 // can call operand and result getters. 786 genNativeTraitAttrVerifier(body, emitHelper); 787 788 for (const auto &namedAttr : emitHelper.getOp().getAttributes()) 789 if (canEmitVerifier(namedAttr.attr)) 790 emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name)); 791 } 792 793 /// Op extra class definitions have a `$cppClass` substitution that is to be 794 /// replaced by the C++ class name. 
static std::string formatExtraDefinitions(const Operator &op) {
  // `cppClass` is the only substitution registered for the extra definitions.
  FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
  return tgfmt(op.getExtraClassDefinition(), &ctx).str();
}

/// Builds the complete C++ class model for `op`. All generation happens here,
/// eagerly; emitDecl/emitDef only finalize and print the result.
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/true) {
  // Substitutions available to verification snippets: the op itself and its
  // MLIRContext, both reached through `this->getOperation()`.
  verifyCtx.withOp("(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genParser();
  genPrinter();
  genVerifier();
  genCustomVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
/// Emits the class declaration for `op`. A fresh OpEmitter is constructed, so
/// all method generation runs again for each call.
void OpEmitter::emitDecl(
    const Operator &op, raw_ostream &os,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
  OpEmitter(op, staticVerifierEmitter).emitDecl(os);
}

/// Emits the out-of-line definitions for `op`, using a fresh OpEmitter.
void OpEmitter::emitDef(
    const Operator &op, raw_ostream &os,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
  OpEmitter(op, staticVerifierEmitter).emitDef(os);
}

void OpEmitter::emitDecl(raw_ostream &os) {
  opClass.finalize();
  opClass.writeDeclTo(os);
}

void OpEmitter::emitDef(raw_ostream &os) {
  opClass.finalize();
  opClass.writeDefTo(os);
}

/// Aborts TableGen with a fatal error if `m` is null. Per the message, a null
/// method means the OpClass rejected it as overlapping an existing one.
/// `line` is the generator source line of the failed request.
static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
                          const Operator &op) {
  if (m)
    return;
  PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
                                   methodName + "` for " +
                                   op.getOperationName() + " (from line " +
                                   Twine(line) + ")");
}

// Records the caller's __LINE__ so the error points at the generating code.
#define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)

/// Generates attribute-name accessors on the op class:
///  - static `getAttributeNames()`, listing names in attribute-metadata order;
///  - private `getAttributeNameForIndex(index)` and its static
///    `(OperationName, index)` overload;
///  - for every getter name of every attribute, `<name>AttrName()` and a
///    static `<name>AttrName(OperationName)`.
/// When the op has no attributes, only `getAttributeNames()` is emitted.
void OpEmitter::genAttrNameGetters() {
  const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
      emitHelper.getAttrMetadata();

  // Emit the getAttributeNames method.
  {
    auto *method = opClass.addStaticInlineMethod(
        "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
    ERROR_IF_PRUNED(method, "getAttributeNames", op);
    auto &body = method->body();
    if (attributes.empty()) {
      body << "  return {};";
      // Nothing else to do if there are no registered attributes. Exit early.
      return;
    }
    body << "  static ::llvm::StringRef attrNames[] = {";
    llvm::interleaveComma(llvm::make_first_range(attributes), body,
                          [&](StringRef attrName) {
                            body << "::llvm::StringRef(\"" << attrName << "\")";
                          });
    body << "};\n  return ::llvm::makeArrayRef(attrNames);";
  }

  // Emit the getAttributeNameForIndex methods.
  {
    // Instance variant: forwards to the static variant with this op's name.
    auto *method = opClass.addInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
    method->body()
        << "  return getAttributeNameForIndex((*this)->getName(), index);";
  }
  {
    // Static variant: indexes the registered op's attribute-name list, with an
    // assert bounding `index` by the attribute count known at generation time.
    auto *method = opClass.addStaticInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("::mlir::OperationName", "name"),
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);

    const char *const getAttrName = R"(
  assert(index < {0} && "invalid attribute index");
  return name.getRegisteredInfo()->getAttributeNames()[index];
)";
    method->body() << formatv(getAttrName, attributes.size());
  }

  // Generate the <attr>AttrName methods, that expose the attribute names to
  // users. {0} is either just the index (instance variant) or "name, <index>"
  // (static variant).
  const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
  for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) {
    for (StringRef name : op.getGetterNames(attrIt.value())) {
      std::string methodName = (name + "AttrName").str();

      // Generate the non-static variant.
      {
        auto *method =
            opClass.addInlineMethod("::mlir::StringAttr", methodName);
        ERROR_IF_PRUNED(method, methodName, op);
        method->body() << llvm::formatv(attrNameMethodBody, attrIt.index());
      }

      // Generate the static variant.
      {
        auto *method = opClass.addStaticInlineMethod(
            "::mlir::StringAttr", methodName,
            MethodParameter("::mlir::OperationName", "name"));
        ERROR_IF_PRUNED(method, methodName, op);
        method->body() << llvm::formatv(attrNameMethodBody,
                                        "name, " + Twine(attrIt.index()));
      }
    }
  }
}

// Emit the getter for an attribute with the return type specified.
// It is templated to be shared between the Op and the adaptor class.
947 template <typename OpClassOrAdaptor> 948 static void emitAttrGetterWithReturnType(FmtContext &fctx, 949 OpClassOrAdaptor &opClass, 950 const Operator &op, StringRef name, 951 Attribute attr) { 952 auto *method = opClass.addMethod(attr.getReturnType(), name); 953 ERROR_IF_PRUNED(method, name, op); 954 auto &body = method->body(); 955 body << " auto attr = " << name << "Attr();\n"; 956 if (attr.hasDefaultValue()) { 957 // Returns the default value if not set. 958 // TODO: this is inefficient, we are recreating the attribute for every 959 // call. This should be set instead. 960 if (!attr.isConstBuildable()) { 961 PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() + 962 " must have a constBuilder"); 963 } 964 std::string defaultValue = std::string( 965 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); 966 body << " if (!attr)\n return " 967 << tgfmt(attr.getConvertFromStorageCall(), 968 &fctx.withSelf(defaultValue)) 969 << ";\n"; 970 } 971 body << " return " 972 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) 973 << ";\n"; 974 } 975 976 void OpEmitter::genAttrGetters() { 977 FmtContext fctx; 978 fctx.withBuilder("::mlir::Builder((*this)->getContext())"); 979 980 // Emit the derived attribute body. 981 auto emitDerivedAttr = [&](StringRef name, Attribute attr) { 982 if (auto *method = opClass.addMethod(attr.getReturnType(), name)) 983 method->body() << " " << attr.getDerivedCodeBody() << "\n"; 984 }; 985 986 // Generate named accessor with Attribute return type. This is a wrapper class 987 // that allows referring to the attributes via accessors instead of having to 988 // use the string interface for better compile time verification. 
989 auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName, 990 Attribute attr) { 991 auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr"); 992 if (!method) 993 return; 994 method->body() << formatv( 995 " return {0}.{1}<{2}>();", emitHelper.getAttr(attrName), 996 attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null" 997 : "cast", 998 attr.getStorageType()); 999 }; 1000 1001 for (const NamedAttribute &namedAttr : op.getAttributes()) { 1002 for (StringRef name : op.getGetterNames(namedAttr.name)) { 1003 if (namedAttr.attr.isDerivedAttr()) { 1004 emitDerivedAttr(name, namedAttr.attr); 1005 } else { 1006 emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr); 1007 emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr); 1008 } 1009 } 1010 } 1011 1012 auto derivedAttrs = make_filter_range(op.getAttributes(), 1013 [](const NamedAttribute &namedAttr) { 1014 return namedAttr.attr.isDerivedAttr(); 1015 }); 1016 if (derivedAttrs.empty()) 1017 return; 1018 1019 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait"); 1020 // Generate helper method to query whether a named attribute is a derived 1021 // attribute. This enables, for example, avoiding adding an attribute that 1022 // overlaps with a derived attribute. 1023 { 1024 auto *method = 1025 opClass.addStaticMethod("bool", "isDerivedAttribute", 1026 MethodParameter("::llvm::StringRef", "name")); 1027 ERROR_IF_PRUNED(method, "isDerivedAttribute", op); 1028 auto &body = method->body(); 1029 for (auto namedAttr : derivedAttrs) 1030 body << " if (name == \"" << namedAttr.name << "\") return true;\n"; 1031 body << " return false;"; 1032 } 1033 // Generate method to materialize derived attributes as a DictionaryAttr. 
1034 { 1035 auto *method = opClass.addMethod("::mlir::DictionaryAttr", 1036 "materializeDerivedAttributes"); 1037 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op); 1038 auto &body = method->body(); 1039 1040 auto nonMaterializable = 1041 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) { 1042 return namedAttr.attr.getConvertFromStorageCall().empty(); 1043 }); 1044 if (!nonMaterializable.empty()) { 1045 std::string attrs; 1046 llvm::raw_string_ostream os(attrs); 1047 interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) { 1048 os << op.getGetterName(attr.name); 1049 }); 1050 PrintWarning( 1051 op.getLoc(), 1052 formatv( 1053 "op has non-materializable derived attributes '{0}', skipping", 1054 os.str())); 1055 body << formatv(" emitOpError(\"op has non-materializable derived " 1056 "attributes '{0}'\");\n", 1057 attrs); 1058 body << " return nullptr;"; 1059 return; 1060 } 1061 1062 body << " ::mlir::MLIRContext* ctx = getContext();\n"; 1063 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; 1064 body << " return ::mlir::DictionaryAttr::get("; 1065 body << " ctx, {\n"; 1066 interleave( 1067 derivedAttrs, body, 1068 [&](const NamedAttribute &namedAttr) { 1069 auto tmpl = namedAttr.attr.getConvertFromStorageCall(); 1070 std::string name = op.getGetterName(namedAttr.name); 1071 body << " {" << name << "AttrName(),\n" 1072 << tgfmt(tmpl, &fctx.withSelf(name + "()") 1073 .withBuilder("odsBuilder") 1074 .addSubst("_ctx", "ctx")) 1075 << "}"; 1076 }, 1077 ",\n"); 1078 body << "});"; 1079 } 1080 } 1081 1082 void OpEmitter::genAttrSetters() { 1083 // Generate raw named setter type. This is a wrapper class that allows setting 1084 // to the attributes via setters instead of having to use the string interface 1085 // for better compile time verification. 
1086 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName, 1087 Attribute attr) { 1088 auto *method = 1089 opClass.addMethod("void", setterName + "Attr", 1090 MethodParameter(attr.getStorageType(), "attr")); 1091 if (method) 1092 method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);", 1093 getterName); 1094 }; 1095 1096 for (const NamedAttribute &namedAttr : op.getAttributes()) { 1097 if (namedAttr.attr.isDerivedAttr()) 1098 continue; 1099 for (auto names : llvm::zip(op.getSetterNames(namedAttr.name), 1100 op.getGetterNames(namedAttr.name))) 1101 emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), 1102 namedAttr.attr); 1103 } 1104 } 1105 1106 void OpEmitter::genOptionalAttrRemovers() { 1107 // Generate methods for removing optional attributes, instead of having to 1108 // use the string interface. Enables better compile time verification. 1109 auto emitRemoveAttr = [&](StringRef name) { 1110 auto upperInitial = name.take_front().upper(); 1111 auto suffix = name.drop_front(); 1112 auto *method = opClass.addMethod("::mlir::Attribute", 1113 "remove" + upperInitial + suffix + "Attr"); 1114 if (!method) 1115 return; 1116 method->body() << formatv(" return (*this)->removeAttr({0}AttrName());", 1117 op.getGetterName(name)); 1118 }; 1119 1120 for (const NamedAttribute &namedAttr : op.getAttributes()) 1121 if (namedAttr.attr.isOptional()) 1122 emitRemoveAttr(namedAttr.name); 1123 } 1124 1125 // Generates the code to compute the start and end index of an operand or result 1126 // range. 
1127 template <typename RangeT> 1128 static void 1129 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName, 1130 int numVariadic, int numNonVariadic, 1131 StringRef rangeSizeCall, bool hasAttrSegmentSize, 1132 StringRef sizeAttrInit, RangeT &&odsValues) { 1133 auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName, 1134 MethodParameter("unsigned", "index")); 1135 if (!method) 1136 return; 1137 auto &body = method->body(); 1138 if (numVariadic == 0) { 1139 body << " return {index, 1};\n"; 1140 } else if (hasAttrSegmentSize) { 1141 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode; 1142 } else { 1143 // Because the op can have arbitrarily interleaved variadic and non-variadic 1144 // operands, we need to embed a list in the "sink" getter method for 1145 // calculation at run-time. 1146 SmallVector<StringRef, 4> isVariadic; 1147 isVariadic.reserve(llvm::size(odsValues)); 1148 for (auto &it : odsValues) 1149 isVariadic.push_back(it.isVariableLength() ? "true" : "false"); 1150 std::string isVariadicList = llvm::join(isVariadic, ", "); 1151 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, 1152 numNonVariadic, numVariadic, rangeSizeCall, "operand"); 1153 } 1154 } 1155 1156 // Generates the named operand getter methods for the given Operator `op` and 1157 // puts them in `opClass`. Uses `rangeType` as the return type of getters that 1158 // return a range of operands (individual operands are `Value ` and each 1159 // element in the range must also be `Value `); use `rangeBeginCall` to get 1160 // an iterator to the beginning of the operand range; use `rangeSizeCall` to 1161 // obtain the number of operands. `getOperandCallPattern` contains the code 1162 // necessary to obtain a single operand whose position will be substituted 1163 // instead of 1164 // "{0}" marker in the pattern. 
Note that the pattern should work for any kind 1165 // of ops, in particular for one-operand ops that may not have the 1166 // `getOperand(unsigned)` method. 1167 static void generateNamedOperandGetters(const Operator &op, Class &opClass, 1168 bool isAdaptor, StringRef sizeAttrInit, 1169 StringRef rangeType, 1170 StringRef rangeBeginCall, 1171 StringRef rangeSizeCall, 1172 StringRef getOperandCallPattern) { 1173 const int numOperands = op.getNumOperands(); 1174 const int numVariadicOperands = op.getNumVariableLengthOperands(); 1175 const int numNormalOperands = numOperands - numVariadicOperands; 1176 1177 const auto *sameVariadicSize = 1178 op.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); 1179 const auto *attrSizedOperands = 1180 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); 1181 1182 if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) { 1183 PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " 1184 "specification over their sizes"); 1185 } 1186 1187 if (numVariadicOperands < 2 && attrSizedOperands) { 1188 PrintFatalError(op.getLoc(), "op must have at least two variadic operands " 1189 "to use 'AttrSizedOperandSegments' trait"); 1190 } 1191 1192 if (attrSizedOperands && sameVariadicSize) { 1193 PrintFatalError(op.getLoc(), 1194 "op cannot have both 'AttrSizedOperandSegments' and " 1195 "'SameVariadicOperandSize' traits"); 1196 } 1197 1198 // First emit a few "sink" getter methods upon which we layer all nicer named 1199 // getter methods. 
1200 generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength", 1201 numVariadicOperands, numNormalOperands, 1202 rangeSizeCall, attrSizedOperands, sizeAttrInit, 1203 const_cast<Operator &>(op).getOperands()); 1204 1205 auto *m = opClass.addMethod(rangeType, "getODSOperands", 1206 MethodParameter("unsigned", "index")); 1207 ERROR_IF_PRUNED(m, "getODSOperands", op); 1208 auto &body = m->body(); 1209 body << formatv(valueRangeReturnCode, rangeBeginCall, 1210 "getODSOperandIndexAndLength(index)"); 1211 1212 // Then we emit nicer named getter methods by redirecting to the "sink" getter 1213 // method. 1214 for (int i = 0; i != numOperands; ++i) { 1215 const auto &operand = op.getOperand(i); 1216 if (operand.name.empty()) 1217 continue; 1218 for (StringRef name : op.getGetterNames(operand.name)) { 1219 if (operand.isOptional()) { 1220 m = opClass.addMethod("::mlir::Value", name); 1221 ERROR_IF_PRUNED(m, name, op); 1222 m->body() << " auto operands = getODSOperands(" << i << ");\n" 1223 << " return operands.empty() ? 
::mlir::Value() : " 1224 "*operands.begin();"; 1225 } else if (operand.isVariadicOfVariadic()) { 1226 std::string segmentAttr = op.getGetterName( 1227 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); 1228 if (isAdaptor) { 1229 m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", 1230 name); 1231 ERROR_IF_PRUNED(m, name, op); 1232 m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, 1233 segmentAttr, i); 1234 continue; 1235 } 1236 1237 m = opClass.addMethod("::mlir::OperandRangeRange", name); 1238 ERROR_IF_PRUNED(m, name, op); 1239 m->body() << " return getODSOperands(" << i << ").split(" 1240 << segmentAttr << "Attr());"; 1241 } else if (operand.isVariadic()) { 1242 m = opClass.addMethod(rangeType, name); 1243 ERROR_IF_PRUNED(m, name, op); 1244 m->body() << " return getODSOperands(" << i << ");"; 1245 } else { 1246 m = opClass.addMethod("::mlir::Value", name); 1247 ERROR_IF_PRUNED(m, name, op); 1248 m->body() << " return *getODSOperands(" << i << ").begin();"; 1249 } 1250 } 1251 } 1252 } 1253 1254 void OpEmitter::genNamedOperandGetters() { 1255 // Build the code snippet used for initializing the operand_segment_size)s 1256 // array. 
1257 std::string attrSizeInitCode; 1258 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 1259 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, 1260 emitHelper.getAttr(operandSegmentAttrName)); 1261 } 1262 1263 generateNamedOperandGetters( 1264 op, opClass, 1265 /*isAdaptor=*/false, 1266 /*sizeAttrInit=*/attrSizeInitCode, 1267 /*rangeType=*/"::mlir::Operation::operand_range", 1268 /*rangeBeginCall=*/"getOperation()->operand_begin()", 1269 /*rangeSizeCall=*/"getOperation()->getNumOperands()", 1270 /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); 1271 } 1272 1273 void OpEmitter::genNamedOperandSetters() { 1274 auto *attrSizedOperands = 1275 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); 1276 for (int i = 0, e = op.getNumOperands(); i != e; ++i) { 1277 const auto &operand = op.getOperand(i); 1278 if (operand.name.empty()) 1279 continue; 1280 for (StringRef name : op.getGetterNames(operand.name)) { 1281 auto *m = opClass.addMethod(operand.isVariadicOfVariadic() 1282 ? "::mlir::MutableOperandRangeRange" 1283 : "::mlir::MutableOperandRange", 1284 (name + "Mutable").str()); 1285 ERROR_IF_PRUNED(m, name, op); 1286 auto &body = m->body(); 1287 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" 1288 << " auto mutableRange = " 1289 "::mlir::MutableOperandRange(getOperation(), " 1290 "range.first, range.second"; 1291 if (attrSizedOperands) { 1292 body << formatv( 1293 ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, 1294 emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true)); 1295 } 1296 body << ");\n"; 1297 1298 // If this operand is a nested variadic, we split the range into a 1299 // MutableOperandRangeRange that provides a range over all of the 1300 // sub-ranges. 
1301 if (operand.isVariadicOfVariadic()) { 1302 body << " return " 1303 "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" 1304 << op.getGetterName( 1305 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) 1306 << "AttrName()));\n"; 1307 } else { 1308 // Otherwise, we use the full range directly. 1309 body << " return mutableRange;\n"; 1310 } 1311 } 1312 } 1313 } 1314 1315 void OpEmitter::genNamedResultGetters() { 1316 const int numResults = op.getNumResults(); 1317 const int numVariadicResults = op.getNumVariableLengthResults(); 1318 const int numNormalResults = numResults - numVariadicResults; 1319 1320 // If we have more than one variadic results, we need more complicated logic 1321 // to calculate the value range for each result. 1322 1323 const auto *sameVariadicSize = 1324 op.getTrait("::mlir::OpTrait::SameVariadicResultSize"); 1325 const auto *attrSizedResults = 1326 op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"); 1327 1328 if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) { 1329 PrintFatalError(op.getLoc(), "op has multiple variadic results but no " 1330 "specification over their sizes"); 1331 } 1332 1333 if (numVariadicResults < 2 && attrSizedResults) { 1334 PrintFatalError(op.getLoc(), "op must have at least two variadic results " 1335 "to use 'AttrSizedResultSegments' trait"); 1336 } 1337 1338 if (attrSizedResults && sameVariadicSize) { 1339 PrintFatalError(op.getLoc(), 1340 "op cannot have both 'AttrSizedResultSegments' and " 1341 "'SameVariadicResultSize' traits"); 1342 } 1343 1344 // Build the initializer string for the result segment size attribute. 
1345 std::string attrSizeInitCode; 1346 if (attrSizedResults) { 1347 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, 1348 emitHelper.getAttr(resultSegmentAttrName)); 1349 } 1350 1351 generateValueRangeStartAndEnd( 1352 opClass, "getODSResultIndexAndLength", numVariadicResults, 1353 numNormalResults, "getOperation()->getNumResults()", attrSizedResults, 1354 attrSizeInitCode, op.getResults()); 1355 1356 auto *m = 1357 opClass.addMethod("::mlir::Operation::result_range", "getODSResults", 1358 MethodParameter("unsigned", "index")); 1359 ERROR_IF_PRUNED(m, "getODSResults", op); 1360 m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", 1361 "getODSResultIndexAndLength(index)"); 1362 1363 for (int i = 0; i != numResults; ++i) { 1364 const auto &result = op.getResult(i); 1365 if (result.name.empty()) 1366 continue; 1367 for (StringRef name : op.getGetterNames(result.name)) { 1368 if (result.isOptional()) { 1369 m = opClass.addMethod("::mlir::Value", name); 1370 ERROR_IF_PRUNED(m, name, op); 1371 m->body() 1372 << " auto results = getODSResults(" << i << ");\n" 1373 << " return results.empty() ? ::mlir::Value() : *results.begin();"; 1374 } else if (result.isVariadic()) { 1375 m = opClass.addMethod("::mlir::Operation::result_range", name); 1376 ERROR_IF_PRUNED(m, name, op); 1377 m->body() << " return getODSResults(" << i << ");"; 1378 } else { 1379 m = opClass.addMethod("::mlir::Value", name); 1380 ERROR_IF_PRUNED(m, name, op); 1381 m->body() << " return *getODSResults(" << i << ").begin();"; 1382 } 1383 } 1384 } 1385 } 1386 1387 void OpEmitter::genNamedRegionGetters() { 1388 unsigned numRegions = op.getNumRegions(); 1389 for (unsigned i = 0; i < numRegions; ++i) { 1390 const auto ®ion = op.getRegion(i); 1391 if (region.name.empty()) 1392 continue; 1393 1394 for (StringRef name : op.getGetterNames(region.name)) { 1395 // Generate the accessors for a variadic region. 
1396 if (region.isVariadic()) { 1397 auto *m = 1398 opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name); 1399 ERROR_IF_PRUNED(m, name, op); 1400 m->body() << formatv(" return (*this)->getRegions().drop_front({0});", 1401 i); 1402 continue; 1403 } 1404 1405 auto *m = opClass.addMethod("::mlir::Region &", name); 1406 ERROR_IF_PRUNED(m, name, op); 1407 m->body() << formatv(" return (*this)->getRegion({0});", i); 1408 } 1409 } 1410 } 1411 1412 void OpEmitter::genNamedSuccessorGetters() { 1413 unsigned numSuccessors = op.getNumSuccessors(); 1414 for (unsigned i = 0; i < numSuccessors; ++i) { 1415 const NamedSuccessor &successor = op.getSuccessor(i); 1416 if (successor.name.empty()) 1417 continue; 1418 1419 for (StringRef name : op.getGetterNames(successor.name)) { 1420 // Generate the accessors for a variadic successor list. 1421 if (successor.isVariadic()) { 1422 auto *m = opClass.addMethod("::mlir::SuccessorRange", name); 1423 ERROR_IF_PRUNED(m, name, op); 1424 m->body() << formatv( 1425 " return {std::next((*this)->successor_begin(), {0}), " 1426 "(*this)->successor_end()};", 1427 i); 1428 continue; 1429 } 1430 1431 auto *m = opClass.addMethod("::mlir::Block *", name); 1432 ERROR_IF_PRUNED(m, name, op); 1433 m->body() << formatv(" return (*this)->getSuccessor({0});", i); 1434 } 1435 } 1436 } 1437 1438 static bool canGenerateUnwrappedBuilder(const Operator &op) { 1439 // If this op does not have native attributes at all, return directly to avoid 1440 // redefining builders. 1441 if (op.getNumNativeAttributes() == 0) 1442 return false; 1443 1444 bool canGenerate = false; 1445 // We are generating builders that take raw values for attributes. We need to 1446 // make sure the native attributes have a meaningful "unwrapped" value type 1447 // different from the wrapped mlir::Attribute type to avoid redefining 1448 // builders. This checks for the op has at least one such native attribute. 
1449 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { 1450 const NamedAttribute &namedAttr = op.getAttribute(i); 1451 if (canUseUnwrappedRawValue(namedAttr.attr)) { 1452 canGenerate = true; 1453 break; 1454 } 1455 } 1456 return canGenerate; 1457 } 1458 1459 static bool canInferType(const Operator &op) { 1460 return op.getTrait("::mlir::InferTypeOpInterface::Trait"); 1461 } 1462 1463 void OpEmitter::genSeparateArgParamBuilder() { 1464 SmallVector<AttrParamKind, 2> attrBuilderType; 1465 attrBuilderType.push_back(AttrParamKind::WrappedAttr); 1466 if (canGenerateUnwrappedBuilder(op)) 1467 attrBuilderType.push_back(AttrParamKind::UnwrappedValue); 1468 1469 // Emit with separate builders with or without unwrapped attributes and/or 1470 // inferring result type. 1471 auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, 1472 bool inferType) { 1473 SmallVector<MethodParameter> paramList; 1474 SmallVector<std::string, 4> resultNames; 1475 llvm::StringSet<> inferredAttributes; 1476 buildParamList(paramList, inferredAttributes, resultNames, paramKind, 1477 attrType); 1478 1479 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1480 // If the builder is redundant, skip generating the method. 1481 if (!m) 1482 return; 1483 auto &body = m->body(); 1484 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, 1485 /*isRawValueAttr=*/attrType == 1486 AttrParamKind::UnwrappedValue); 1487 1488 // Push all result types to the operation state 1489 1490 if (inferType) { 1491 // Generate builder that infers type too. 1492 // TODO: Subsume this with general checking if type can be 1493 // inferred automatically. 1494 // TODO: Expand to handle regions. 
1495 body << formatv(R"( 1496 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; 1497 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), 1498 {1}.location, {1}.operands, 1499 {1}.attributes.getDictionary({1}.getContext()), 1500 /*regions=*/{{}, inferredReturnTypes))) 1501 {1}.addTypes(inferredReturnTypes); 1502 else 1503 ::llvm::report_fatal_error("Failed to infer result type(s).");)", 1504 opClass.getClassName(), builderOpState); 1505 return; 1506 } 1507 1508 switch (paramKind) { 1509 case TypeParamKind::None: 1510 return; 1511 case TypeParamKind::Separate: 1512 for (int i = 0, e = op.getNumResults(); i < e; ++i) { 1513 if (op.getResult(i).isOptional()) 1514 body << " if (" << resultNames[i] << ")\n "; 1515 body << " " << builderOpState << ".addTypes(" << resultNames[i] 1516 << ");\n"; 1517 } 1518 return; 1519 case TypeParamKind::Collective: { 1520 int numResults = op.getNumResults(); 1521 int numVariadicResults = op.getNumVariableLengthResults(); 1522 int numNonVariadicResults = numResults - numVariadicResults; 1523 bool hasVariadicResult = numVariadicResults != 0; 1524 1525 // Avoid emitting "resultTypes.size() >= 0u" which is always true. 1526 if (!(hasVariadicResult && numNonVariadicResults == 0)) 1527 body << " " 1528 << "assert(resultTypes.size() " 1529 << (hasVariadicResult ? ">=" : "==") << " " 1530 << numNonVariadicResults 1531 << "u && \"mismatched number of results\");\n"; 1532 body << " " << builderOpState << ".addTypes(resultTypes);\n"; 1533 } 1534 return; 1535 } 1536 llvm_unreachable("unhandled TypeParamKind"); 1537 }; 1538 1539 // Some of the build methods generated here may be ambiguous, but TableGen's 1540 // ambiguous function detection will elide those ones. 
1541 for (auto attrType : attrBuilderType) { 1542 emit(attrType, TypeParamKind::Separate, /*inferType=*/false); 1543 if (canInferType(op) && op.getNumRegions() == 0) 1544 emit(attrType, TypeParamKind::None, /*inferType=*/true); 1545 emit(attrType, TypeParamKind::Collective, /*inferType=*/false); 1546 } 1547 } 1548 1549 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { 1550 int numResults = op.getNumResults(); 1551 1552 // Signature 1553 SmallVector<MethodParameter> paramList; 1554 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 1555 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1556 paramList.emplace_back("::mlir::ValueRange", "operands"); 1557 // Provide default value for `attributes` when its the last parameter 1558 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; 1559 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 1560 "attributes", attributesDefaultValue); 1561 if (op.getNumVariadicRegions()) 1562 paramList.emplace_back("unsigned", "numRegions"); 1563 1564 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1565 // If the builder is redundant, skip generating the method 1566 if (!m) 1567 return; 1568 auto &body = m->body(); 1569 1570 // Operands 1571 body << " " << builderOpState << ".addOperands(operands);\n"; 1572 1573 // Attributes 1574 body << " " << builderOpState << ".addAttributes(attributes);\n"; 1575 1576 // Create the correct number of regions 1577 if (int numRegions = op.getNumRegions()) { 1578 body << llvm::formatv( 1579 " for (unsigned i = 0; i != {0}; ++i)\n", 1580 (op.getNumVariadicRegions() ? 
"numRegions" : Twine(numRegions))); 1581 body << " (void)" << builderOpState << ".addRegion();\n"; 1582 } 1583 1584 // Result types 1585 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()"); 1586 body << " " << builderOpState << ".addTypes({" 1587 << llvm::join(resultTypes, ", ") << "});\n\n"; 1588 } 1589 1590 void OpEmitter::genInferredTypeCollectiveParamBuilder() { 1591 SmallVector<MethodParameter> paramList; 1592 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 1593 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1594 paramList.emplace_back("::mlir::ValueRange", "operands"); 1595 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; 1596 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 1597 "attributes", attributesDefaultValue); 1598 if (op.getNumVariadicRegions()) 1599 paramList.emplace_back("unsigned", "numRegions"); 1600 1601 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1602 // If the builder is redundant, skip generating the method 1603 if (!m) 1604 return; 1605 auto &body = m->body(); 1606 1607 int numResults = op.getNumResults(); 1608 int numVariadicResults = op.getNumVariableLengthResults(); 1609 int numNonVariadicResults = numResults - numVariadicResults; 1610 1611 int numOperands = op.getNumOperands(); 1612 int numVariadicOperands = op.getNumVariableLengthOperands(); 1613 int numNonVariadicOperands = numOperands - numVariadicOperands; 1614 1615 // Operands 1616 if (numVariadicOperands == 0 || numNonVariadicOperands != 0) 1617 body << " assert(operands.size()" 1618 << (numVariadicOperands != 0 ? 
" >= " : " == ") 1619 << numNonVariadicOperands 1620 << "u && \"mismatched number of parameters\");\n"; 1621 body << " " << builderOpState << ".addOperands(operands);\n"; 1622 body << " " << builderOpState << ".addAttributes(attributes);\n"; 1623 1624 // Create the correct number of regions 1625 if (int numRegions = op.getNumRegions()) { 1626 body << llvm::formatv( 1627 " for (unsigned i = 0; i != {0}; ++i)\n", 1628 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); 1629 body << " (void)" << builderOpState << ".addRegion();\n"; 1630 } 1631 1632 // Result types 1633 body << formatv(R"( 1634 ::mlir::SmallVector<::mlir::Type, 2> inferredReturnTypes; 1635 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), 1636 {1}.location, operands, 1637 {1}.attributes.getDictionary({1}.getContext()), 1638 {1}.regions, inferredReturnTypes))) {{)", 1639 opClass.getClassName(), builderOpState); 1640 if (numVariadicResults == 0 || numNonVariadicResults != 0) 1641 body << "\n assert(inferredReturnTypes.size()" 1642 << (numVariadicResults != 0 ? 
" >= " : " == ") << numNonVariadicResults 1643 << "u && \"mismatched number of return types\");"; 1644 body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);"; 1645 1646 body << formatv(R"( 1647 } else {{ 1648 ::llvm::report_fatal_error("Failed to infer result type(s)."); 1649 })", 1650 opClass.getClassName(), builderOpState); 1651 } 1652 1653 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { 1654 auto emit = [&](AttrParamKind attrType) { 1655 SmallVector<MethodParameter> paramList; 1656 SmallVector<std::string, 4> resultNames; 1657 llvm::StringSet<> inferredAttributes; 1658 buildParamList(paramList, inferredAttributes, resultNames, 1659 TypeParamKind::None, attrType); 1660 1661 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1662 // If the builder is redundant, skip generating the method 1663 if (!m) 1664 return; 1665 auto &body = m->body(); 1666 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, 1667 /*isRawValueAttr=*/attrType == 1668 AttrParamKind::UnwrappedValue); 1669 1670 auto numResults = op.getNumResults(); 1671 if (numResults == 0) 1672 return; 1673 1674 // Push all result types to the operation state 1675 const char *index = op.getOperand(0).isVariadic() ? 
".front()" : ""; 1676 std::string resultType = 1677 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); 1678 body << " " << builderOpState << ".addTypes({" << resultType; 1679 for (int i = 1; i != numResults; ++i) 1680 body << ", " << resultType; 1681 body << "});\n\n"; 1682 }; 1683 1684 emit(AttrParamKind::WrappedAttr); 1685 // Generate additional builder(s) if any attributes can be "unwrapped" 1686 if (canGenerateUnwrappedBuilder(op)) 1687 emit(AttrParamKind::UnwrappedValue); 1688 } 1689 1690 void OpEmitter::genUseAttrAsResultTypeBuilder() { 1691 SmallVector<MethodParameter> paramList; 1692 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 1693 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1694 paramList.emplace_back("::mlir::ValueRange", "operands"); 1695 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 1696 "attributes", "{}"); 1697 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1698 // If the builder is redundant, skip generating the method 1699 if (!m) 1700 return; 1701 1702 auto &body = m->body(); 1703 1704 // Push all result types to the operation state 1705 std::string resultType; 1706 const auto &namedAttr = op.getAttribute(0); 1707 1708 body << " auto attrName = " << op.getGetterName(namedAttr.name) 1709 << "AttrName(" << builderOpState 1710 << ".name);\n" 1711 " for (auto attr : attributes) {\n" 1712 " if (attr.getName() != attrName) continue;\n"; 1713 if (namedAttr.attr.isTypeAttr()) { 1714 resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()"; 1715 } else { 1716 resultType = "attr.getValue().getType()"; 1717 } 1718 1719 // Operands 1720 body << " " << builderOpState << ".addOperands(operands);\n"; 1721 1722 // Attributes 1723 body << " " << builderOpState << ".addAttributes(attributes);\n"; 1724 1725 // Result types 1726 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType); 1727 body << " " << builderOpState << ".addTypes({" 1728 
<< llvm::join(resultTypes, ", ") << "});\n"; 1729 body << " }\n"; 1730 } 1731 1732 /// Returns a signature of the builder. Updates the context `fctx` to enable 1733 /// replacement of $_builder and $_state in the body. 1734 static SmallVector<MethodParameter> 1735 getBuilderSignature(const Builder &builder) { 1736 ArrayRef<Builder::Parameter> params(builder.getParameters()); 1737 1738 // Inject builder and state arguments. 1739 SmallVector<MethodParameter> arguments; 1740 arguments.reserve(params.size() + 2); 1741 arguments.emplace_back("::mlir::OpBuilder &", odsBuilder); 1742 arguments.emplace_back("::mlir::OperationState &", builderOpState); 1743 1744 for (unsigned i = 0, e = params.size(); i < e; ++i) { 1745 // If no name is provided, generate one. 1746 Optional<StringRef> paramName = params[i].getName(); 1747 std::string name = 1748 paramName ? paramName->str() : "odsArg" + std::to_string(i); 1749 1750 StringRef defaultValue; 1751 if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue()) 1752 defaultValue = *defaultParamValue; 1753 1754 arguments.emplace_back(params[i].getCppType(), std::move(name), 1755 defaultValue); 1756 } 1757 1758 return arguments; 1759 } 1760 1761 void OpEmitter::genBuilder() { 1762 // Handle custom builders if provided. 1763 for (const Builder &builder : op.getBuilders()) { 1764 SmallVector<MethodParameter> arguments = getBuilderSignature(builder); 1765 1766 Optional<StringRef> body = builder.getBody(); 1767 auto properties = body ? Method::Static : Method::StaticDeclaration; 1768 auto *method = 1769 opClass.addMethod("void", "build", properties, std::move(arguments)); 1770 if (body) 1771 ERROR_IF_PRUNED(method, "build", op); 1772 1773 FmtContext fctx; 1774 fctx.withBuilder(odsBuilder); 1775 fctx.addSubst("_state", builderOpState); 1776 if (body) 1777 method->body() << tgfmt(*body, &fctx); 1778 } 1779 1780 // Generate default builders that requires all result type, operands, and 1781 // attributes as parameters. 
1782 if (op.skipDefaultBuilders()) 1783 return; 1784 1785 // We generate three classes of builders here: 1786 // 1. one having a stand-alone parameter for each operand / attribute, and 1787 genSeparateArgParamBuilder(); 1788 // 2. one having an aggregated parameter for all result types / operands / 1789 // attributes, and 1790 genCollectiveParamBuilder(); 1791 // 3. one having a stand-alone parameter for each operand and attribute, 1792 // use the first operand or attribute's type as all result types 1793 // to facilitate different call patterns. 1794 if (op.getNumVariableLengthResults() == 0) { 1795 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { 1796 genUseOperandAsResultTypeSeparateParamBuilder(); 1797 genUseOperandAsResultTypeCollectiveParamBuilder(); 1798 } 1799 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) 1800 genUseAttrAsResultTypeBuilder(); 1801 } 1802 } 1803 1804 void OpEmitter::genCollectiveParamBuilder() { 1805 int numResults = op.getNumResults(); 1806 int numVariadicResults = op.getNumVariableLengthResults(); 1807 int numNonVariadicResults = numResults - numVariadicResults; 1808 1809 int numOperands = op.getNumOperands(); 1810 int numVariadicOperands = op.getNumVariableLengthOperands(); 1811 int numNonVariadicOperands = numOperands - numVariadicOperands; 1812 1813 SmallVector<MethodParameter> paramList; 1814 paramList.emplace_back("::mlir::OpBuilder &", ""); 1815 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1816 paramList.emplace_back("::mlir::TypeRange", "resultTypes"); 1817 paramList.emplace_back("::mlir::ValueRange", "operands"); 1818 // Provide default value for `attributes` when its the last parameter 1819 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? 
"" : "{}"; 1820 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", 1821 "attributes", attributesDefaultValue); 1822 if (op.getNumVariadicRegions()) 1823 paramList.emplace_back("unsigned", "numRegions"); 1824 1825 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); 1826 // If the builder is redundant, skip generating the method 1827 if (!m) 1828 return; 1829 auto &body = m->body(); 1830 1831 // Operands 1832 if (numVariadicOperands == 0 || numNonVariadicOperands != 0) 1833 body << " assert(operands.size()" 1834 << (numVariadicOperands != 0 ? " >= " : " == ") 1835 << numNonVariadicOperands 1836 << "u && \"mismatched number of parameters\");\n"; 1837 body << " " << builderOpState << ".addOperands(operands);\n"; 1838 1839 // Attributes 1840 body << " " << builderOpState << ".addAttributes(attributes);\n"; 1841 1842 // Create the correct number of regions 1843 if (int numRegions = op.getNumRegions()) { 1844 body << llvm::formatv( 1845 " for (unsigned i = 0; i != {0}; ++i)\n", 1846 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); 1847 body << " (void)" << builderOpState << ".addRegion();\n"; 1848 } 1849 1850 // Result types 1851 if (numVariadicResults == 0 || numNonVariadicResults != 0) 1852 body << " assert(resultTypes.size()" 1853 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults 1854 << "u && \"mismatched number of return types\");\n"; 1855 body << " " << builderOpState << ".addTypes(resultTypes);\n"; 1856 1857 // Generate builder that infers type too. 1858 // TODO: Expand to handle successors. 
1859 if (canInferType(op) && op.getNumSuccessors() == 0) 1860 genInferredTypeCollectiveParamBuilder(); 1861 } 1862 1863 void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList, 1864 llvm::StringSet<> &inferredAttributes, 1865 SmallVectorImpl<std::string> &resultTypeNames, 1866 TypeParamKind typeParamKind, 1867 AttrParamKind attrParamKind) { 1868 resultTypeNames.clear(); 1869 auto numResults = op.getNumResults(); 1870 resultTypeNames.reserve(numResults); 1871 1872 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); 1873 paramList.emplace_back("::mlir::OperationState &", builderOpState); 1874 1875 switch (typeParamKind) { 1876 case TypeParamKind::None: 1877 break; 1878 case TypeParamKind::Separate: { 1879 // Add parameters for all return types 1880 for (int i = 0; i < numResults; ++i) { 1881 const auto &result = op.getResult(i); 1882 std::string resultName = std::string(result.name); 1883 if (resultName.empty()) 1884 resultName = std::string(formatv("resultType{0}", i)); 1885 1886 StringRef type = 1887 result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type"; 1888 1889 paramList.emplace_back(type, resultName, result.isOptional()); 1890 resultTypeNames.emplace_back(std::move(resultName)); 1891 } 1892 } break; 1893 case TypeParamKind::Collective: { 1894 paramList.emplace_back("::mlir::TypeRange", "resultTypes"); 1895 resultTypeNames.push_back("resultTypes"); 1896 } break; 1897 } 1898 1899 // Add parameters for all arguments (operands and attributes). 1900 int defaultValuedAttrStartIndex = op.getNumArgs(); 1901 // Successors and variadic regions go at the end of the parameter list, so no 1902 // default arguments are possible. 1903 bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions(); 1904 if (attrParamKind == AttrParamKind::UnwrappedValue && !hasTrailingParams) { 1905 // Calculate the start index from which we can attach default values in the 1906 // builder declaration. 
1907 for (int i = op.getNumArgs() - 1; i >= 0; --i) { 1908 auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>(); 1909 if (!namedAttr || !namedAttr->attr.hasDefaultValue()) 1910 break; 1911 1912 if (!canUseUnwrappedRawValue(namedAttr->attr)) 1913 break; 1914 1915 // Creating an APInt requires us to provide bitwidth, value, and 1916 // signedness, which is complicated compared to others. Similarly 1917 // for APFloat. 1918 // TODO: Adjust the 'returnType' field of such attributes 1919 // to support them. 1920 StringRef retType = namedAttr->attr.getReturnType(); 1921 if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") 1922 break; 1923 1924 defaultValuedAttrStartIndex = i; 1925 } 1926 } 1927 1928 /// Collect any inferred attributes. 1929 for (const NamedTypeConstraint &operand : op.getOperands()) { 1930 if (operand.isVariadicOfVariadic()) { 1931 inferredAttributes.insert( 1932 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); 1933 } 1934 } 1935 1936 for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { 1937 Argument arg = op.getArg(i); 1938 if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) { 1939 StringRef type; 1940 if (operand->isVariadicOfVariadic()) 1941 type = "::llvm::ArrayRef<::mlir::ValueRange>"; 1942 else if (operand->isVariadic()) 1943 type = "::mlir::ValueRange"; 1944 else 1945 type = "::mlir::Value"; 1946 1947 paramList.emplace_back(type, getArgumentName(op, numOperands++), 1948 operand->isOptional()); 1949 continue; 1950 } 1951 const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>(); 1952 const Attribute &attr = namedAttr.attr; 1953 1954 // Inferred attributes don't need to be added to the param list. 
1955 if (inferredAttributes.contains(namedAttr.name)) 1956 continue; 1957 1958 StringRef type; 1959 switch (attrParamKind) { 1960 case AttrParamKind::WrappedAttr: 1961 type = attr.getStorageType(); 1962 break; 1963 case AttrParamKind::UnwrappedValue: 1964 if (canUseUnwrappedRawValue(attr)) 1965 type = attr.getReturnType(); 1966 else 1967 type = attr.getStorageType(); 1968 break; 1969 } 1970 1971 // Attach default value if requested and possible. 1972 std::string defaultValue; 1973 if (attrParamKind == AttrParamKind::UnwrappedValue && 1974 i >= defaultValuedAttrStartIndex) { 1975 defaultValue += attr.getDefaultValue(); 1976 } 1977 paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue), 1978 attr.isOptional()); 1979 } 1980 1981 /// Insert parameters for each successor. 1982 for (const NamedSuccessor &succ : op.getSuccessors()) { 1983 StringRef type = 1984 succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *"; 1985 paramList.emplace_back(type, succ.name); 1986 } 1987 1988 /// Insert parameters for variadic regions. 1989 for (const NamedRegion ®ion : op.getRegions()) 1990 if (region.isVariadic()) 1991 paramList.emplace_back("unsigned", 1992 llvm::formatv("{0}Count", region.name).str()); 1993 } 1994 1995 void OpEmitter::genCodeForAddingArgAndRegionForBuilder( 1996 MethodBody &body, llvm::StringSet<> &inferredAttributes, 1997 bool isRawValueAttr) { 1998 // Push all operands to the result. 1999 for (int i = 0, e = op.getNumOperands(); i < e; ++i) { 2000 std::string argName = getArgumentName(op, i); 2001 const NamedTypeConstraint &operand = op.getOperand(i); 2002 if (operand.constraint.isVariadicOfVariadic()) { 2003 body << " for (::mlir::ValueRange range : " << argName << ")\n " 2004 << builderOpState << ".addOperands(range);\n"; 2005 2006 // Add the segment attribute. 
2007 body << " {\n" 2008 << " ::llvm::SmallVector<int32_t> rangeSegments;\n" 2009 << " for (::mlir::ValueRange range : " << argName << ")\n" 2010 << " rangeSegments.push_back(range.size());\n" 2011 << " " << builderOpState << ".addAttribute(" 2012 << op.getGetterName( 2013 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) 2014 << "AttrName(" << builderOpState << ".name), " << odsBuilder 2015 << ".getI32TensorAttr(rangeSegments));" 2016 << " }\n"; 2017 continue; 2018 } 2019 2020 if (operand.isOptional()) 2021 body << " if (" << argName << ")\n "; 2022 body << " " << builderOpState << ".addOperands(" << argName << ");\n"; 2023 } 2024 2025 // If the operation has the operand segment size attribute, add it here. 2026 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 2027 std::string sizes = op.getGetterName(operandSegmentAttrName); 2028 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" 2029 << builderOpState << ".name), " 2030 << "odsBuilder.getI32VectorAttr({"; 2031 interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) { 2032 const NamedTypeConstraint &operand = op.getOperand(i); 2033 if (!operand.isVariableLength()) { 2034 body << "1"; 2035 return; 2036 } 2037 2038 std::string operandName = getArgumentName(op, i); 2039 if (operand.isOptional()) { 2040 body << "(" << operandName << " ? 1 : 0)"; 2041 } else if (operand.isVariadicOfVariadic()) { 2042 body << llvm::formatv( 2043 "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, " 2044 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + " 2045 "range.size(); }))", 2046 operandName); 2047 } else { 2048 body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())"; 2049 } 2050 }); 2051 body << "}));\n"; 2052 } 2053 2054 // Push all attributes to the result. 
2055 for (const auto &namedAttr : op.getAttributes()) { 2056 auto &attr = namedAttr.attr; 2057 if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) 2058 continue; 2059 2060 bool emitNotNullCheck = 2061 attr.isOptional() || (attr.hasDefaultValue() && !isRawValueAttr); 2062 if (emitNotNullCheck) 2063 body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; 2064 2065 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { 2066 // If this is a raw value, then we need to wrap it in an Attribute 2067 // instance. 2068 FmtContext fctx; 2069 fctx.withBuilder("odsBuilder"); 2070 2071 std::string builderTemplate = std::string(attr.getConstBuilderTemplate()); 2072 2073 // For StringAttr, its constant builder call will wrap the input in 2074 // quotes, which is correct for normal string literals, but incorrect 2075 // here given we use function arguments. So we need to strip the 2076 // wrapping quotes. 2077 if (StringRef(builderTemplate).contains("\"$0\"")) 2078 builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); 2079 2080 std::string value = 2081 std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); 2082 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", 2083 builderOpState, op.getGetterName(namedAttr.name), value); 2084 } else { 2085 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", 2086 builderOpState, op.getGetterName(namedAttr.name), 2087 namedAttr.name); 2088 } 2089 if (emitNotNullCheck) 2090 body << " }\n"; 2091 } 2092 2093 // Create the correct number of regions. 2094 for (const NamedRegion ®ion : op.getRegions()) { 2095 if (region.isVariadic()) 2096 body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", 2097 region.name); 2098 2099 body << " (void)" << builderOpState << ".addRegion();\n"; 2100 } 2101 2102 // Push all successors to the result. 
2103 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { 2104 body << formatv(" {0}.addSuccessors({1});\n", builderOpState, 2105 namedSuccessor.name); 2106 } 2107 } 2108 2109 void OpEmitter::genCanonicalizerDecls() { 2110 bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod"); 2111 if (hasCanonicalizeMethod) { 2112 // static LogicResult FooOp:: 2113 // canonicalize(FooOp op, PatternRewriter &rewriter); 2114 SmallVector<MethodParameter> paramList; 2115 paramList.emplace_back(op.getCppClassName(), "op"); 2116 paramList.emplace_back("::mlir::PatternRewriter &", "rewriter"); 2117 auto *m = opClass.declareStaticMethod("::mlir::LogicalResult", 2118 "canonicalize", std::move(paramList)); 2119 ERROR_IF_PRUNED(m, "canonicalize", op); 2120 } 2121 2122 // We get a prototype for 'getCanonicalizationPatterns' if requested directly 2123 // or if using a 'canonicalize' method. 2124 bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer"); 2125 if (!hasCanonicalizeMethod && !hasCanonicalizer) 2126 return; 2127 2128 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize' 2129 // method, but not implementing 'getCanonicalizationPatterns' manually. 2130 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer; 2131 2132 // Add a signature for getCanonicalizationPatterns if implemented by the 2133 // dialect or if synthesized to call 'canonicalize'. 2134 SmallVector<MethodParameter> paramList; 2135 paramList.emplace_back("::mlir::RewritePatternSet &", "results"); 2136 paramList.emplace_back("::mlir::MLIRContext *", "context"); 2137 auto kind = hasBody ? Method::Static : Method::StaticDeclaration; 2138 auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind, 2139 std::move(paramList)); 2140 2141 // If synthesizing the method, fill it it. 
2142 if (hasBody) { 2143 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op); 2144 method->body() << " results.add(canonicalize);\n"; 2145 } 2146 } 2147 2148 void OpEmitter::genFolderDecls() { 2149 bool hasSingleResult = 2150 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0; 2151 2152 if (def.getValueAsBit("hasFolder")) { 2153 if (hasSingleResult) { 2154 auto *m = opClass.declareMethod( 2155 "::mlir::OpFoldResult", "fold", 2156 MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands")); 2157 ERROR_IF_PRUNED(m, "operands", op); 2158 } else { 2159 SmallVector<MethodParameter> paramList; 2160 paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); 2161 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", 2162 "results"); 2163 auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold", 2164 std::move(paramList)); 2165 ERROR_IF_PRUNED(m, "fold", op); 2166 } 2167 } 2168 } 2169 2170 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) { 2171 Interface interface = opTrait->getInterface(); 2172 2173 // Get the set of methods that should always be declared. 2174 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods(); 2175 llvm::StringSet<> alwaysDeclaredMethods; 2176 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), 2177 alwaysDeclaredMethodsVec.end()); 2178 2179 for (const InterfaceMethod &method : interface.getMethods()) { 2180 // Don't declare if the method has a body. 2181 if (method.getBody()) 2182 continue; 2183 // Don't declare if the method has a default implementation and the op 2184 // didn't request that it always be declared. 2185 if (method.getDefaultImplementation() && 2186 !alwaysDeclaredMethods.count(method.getName())) 2187 continue; 2188 // Interface methods are allowed to overlap with existing methods, so don't 2189 // check if pruned. 
2190 (void)genOpInterfaceMethod(method); 2191 } 2192 } 2193 2194 Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method, 2195 bool declaration) { 2196 SmallVector<MethodParameter> paramList; 2197 for (const InterfaceMethod::Argument &arg : method.getArguments()) 2198 paramList.emplace_back(arg.type, arg.name); 2199 2200 auto props = (method.isStatic() ? Method::Static : Method::None) | 2201 (declaration ? Method::Declaration : Method::None); 2202 return opClass.addMethod(method.getReturnType(), method.getName(), props, 2203 std::move(paramList)); 2204 } 2205 2206 void OpEmitter::genOpInterfaceMethods() { 2207 for (const auto &trait : op.getTraits()) { 2208 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) 2209 if (opTrait->shouldDeclareMethods()) 2210 genOpInterfaceMethods(opTrait); 2211 } 2212 } 2213 2214 void OpEmitter::genSideEffectInterfaceMethods() { 2215 enum EffectKind { Operand, Result, Symbol, Static }; 2216 struct EffectLocation { 2217 /// The effect applied. 2218 SideEffect effect; 2219 2220 /// The index if the kind is not static. 2221 unsigned index; 2222 2223 /// The kind of the location. 2224 unsigned kind; 2225 }; 2226 2227 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects; 2228 auto resolveDecorators = [&](Operator::var_decorator_range decorators, 2229 unsigned index, unsigned kind) { 2230 for (auto decorator : decorators) 2231 if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) { 2232 opClass.addTrait(effect->getInterfaceTrait()); 2233 interfaceEffects[effect->getBaseEffectName()].push_back( 2234 EffectLocation{*effect, index, kind}); 2235 } 2236 }; 2237 2238 // Collect effects that were specified via: 2239 /// Traits. 
2240 for (const auto &trait : op.getTraits()) { 2241 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait); 2242 if (!opTrait) 2243 continue; 2244 auto &effects = interfaceEffects[opTrait->getBaseEffectName()]; 2245 for (auto decorator : opTrait->getEffects()) 2246 effects.push_back(EffectLocation{cast<SideEffect>(decorator), 2247 /*index=*/0, EffectKind::Static}); 2248 } 2249 /// Attributes and Operands. 2250 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) { 2251 Argument arg = op.getArg(i); 2252 if (arg.is<NamedTypeConstraint *>()) { 2253 resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand); 2254 ++operandIt; 2255 continue; 2256 } 2257 const NamedAttribute *attr = arg.get<NamedAttribute *>(); 2258 if (attr->attr.getBaseAttr().isSymbolRefAttr()) 2259 resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol); 2260 } 2261 /// Results. 2262 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) 2263 resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result); 2264 2265 // The code used to add an effect instance. 2266 // {0}: The effect class. 2267 // {1}: Optional value or symbol reference. 2268 // {1}: The resource class. 2269 const char *addEffectCode = 2270 " effects.emplace_back({0}::get(), {1}{2}::get());\n"; 2271 2272 for (auto &it : interfaceEffects) { 2273 // Generate the 'getEffects' method. 2274 std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::" 2275 "SideEffects::EffectInstance<{0}>> &", 2276 it.first()) 2277 .str(); 2278 auto *getEffects = opClass.addMethod("void", "getEffects", 2279 MethodParameter(type, "effects")); 2280 ERROR_IF_PRUNED(getEffects, "getEffects", op); 2281 auto &body = getEffects->body(); 2282 2283 // Add effect instances for each of the locations marked on the operation. 
2284 for (auto &location : it.second) { 2285 StringRef effect = location.effect.getName(); 2286 StringRef resource = location.effect.getResource(); 2287 if (location.kind == EffectKind::Static) { 2288 // A static instance has no attached value. 2289 body << llvm::formatv(addEffectCode, effect, "", resource).str(); 2290 } else if (location.kind == EffectKind::Symbol) { 2291 // A symbol reference requires adding the proper attribute. 2292 const auto *attr = op.getArg(location.index).get<NamedAttribute *>(); 2293 std::string argName = op.getGetterName(attr->name); 2294 if (attr->attr.isOptional()) { 2295 body << " if (auto symbolRef = " << argName << "Attr())\n " 2296 << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource) 2297 .str(); 2298 } else { 2299 body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ", 2300 resource) 2301 .str(); 2302 } 2303 } else { 2304 // Otherwise this is an operand/result, so we need to attach the Value. 2305 body << " for (::mlir::Value value : getODS" 2306 << (location.kind == EffectKind::Operand ? "Operands" : "Results") 2307 << "(" << location.index << "))\n " 2308 << llvm::formatv(addEffectCode, effect, "value, ", resource).str(); 2309 } 2310 } 2311 } 2312 } 2313 2314 void OpEmitter::genTypeInterfaceMethods() { 2315 if (!op.allResultTypesKnown()) 2316 return; 2317 // Generate 'inferReturnTypes' method declaration using the interface method 2318 // declared in 'InferTypeOpInterface' op interface. 
2319 const auto *trait = 2320 cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait")); 2321 Interface interface = trait->getInterface(); 2322 Method *method = [&]() -> Method * { 2323 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) { 2324 if (interfaceMethod.getName() == "inferReturnTypes") { 2325 return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false); 2326 } 2327 } 2328 assert(0 && "unable to find inferReturnTypes interface method"); 2329 return nullptr; 2330 }(); 2331 ERROR_IF_PRUNED(method, "inferReturnTypes", op); 2332 auto &body = method->body(); 2333 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n"; 2334 2335 FmtContext fctx; 2336 fctx.withBuilder("odsBuilder"); 2337 body << " ::mlir::Builder odsBuilder(context);\n"; 2338 2339 auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & { 2340 if (!type.isArg()) 2341 return body << tgfmt(*type.getType().getBuilderCall(), &fctx); 2342 auto argIndex = type.getArg(); 2343 assert(!op.getArg(argIndex).is<NamedAttribute *>()); 2344 auto arg = op.getArgToOperandOrAttribute(argIndex); 2345 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) 2346 return body << "operands[" << arg.operandOrAttributeIndex() 2347 << "].getType()"; 2348 return body << "attributes[" << arg.operandOrAttributeIndex() 2349 << "].getType()"; 2350 }; 2351 2352 for (int i = 0, e = op.getNumResults(); i != e; ++i) { 2353 body << " inferredReturnTypes[" << i << "] = "; 2354 auto types = op.getSameTypeAsResult(i); 2355 emitType(types[0]) << ";\n"; 2356 if (types.size() == 1) 2357 continue; 2358 // TODO: We could verify equality here, but skipping that for verification. 
2359 } 2360 body << " return ::mlir::success();"; 2361 } 2362 2363 void OpEmitter::genParser() { 2364 if (hasStringAttribute(def, "assemblyFormat")) 2365 return; 2366 2367 if (!def.getValueAsBit("hasCustomAssemblyFormat")) 2368 return; 2369 2370 SmallVector<MethodParameter> paramList; 2371 paramList.emplace_back("::mlir::OpAsmParser &", "parser"); 2372 paramList.emplace_back("::mlir::OperationState &", "result"); 2373 2374 auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse", 2375 std::move(paramList)); 2376 ERROR_IF_PRUNED(method, "parse", op); 2377 } 2378 2379 void OpEmitter::genPrinter() { 2380 if (hasStringAttribute(def, "assemblyFormat")) 2381 return; 2382 2383 // Check to see if this op uses a c++ format. 2384 if (!def.getValueAsBit("hasCustomAssemblyFormat")) 2385 return; 2386 auto *method = opClass.declareMethod( 2387 "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p")); 2388 ERROR_IF_PRUNED(method, "print", op); 2389 } 2390 2391 void OpEmitter::genVerifier() { 2392 auto *implMethod = 2393 opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl"); 2394 ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op); 2395 auto &implBody = implMethod->body(); 2396 2397 populateSubstitutions(emitHelper, verifyCtx); 2398 genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter); 2399 genOperandResultVerifier(implBody, op.getOperands(), "operand"); 2400 genOperandResultVerifier(implBody, op.getResults(), "result"); 2401 2402 for (auto &trait : op.getTraits()) { 2403 if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) { 2404 implBody << tgfmt(" if (!($0))\n " 2405 "return emitOpError(\"failed to verify that $1\");\n", 2406 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), 2407 t->getSummary()); 2408 } 2409 } 2410 2411 genRegionVerifier(implBody); 2412 genSuccessorVerifier(implBody); 2413 2414 implBody << " return ::mlir::success();\n"; 2415 2416 // TODO: Some places use the `verifyInvariants` to do operation 
verification. 2417 // This may not act as their expectation because this doesn't call any 2418 // verifiers of native/interface traits. Needs to review those use cases and 2419 // see if we should use the mlir::verify() instead. 2420 auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants"); 2421 ERROR_IF_PRUNED(method, "verifyInvariants", op); 2422 auto &body = method->body(); 2423 if (def.getValueAsBit("hasVerifier")) { 2424 body << " if(::mlir::succeeded(verifyInvariantsImpl()) && " 2425 "::mlir::succeeded(verify()))\n"; 2426 body << " return ::mlir::success();\n"; 2427 body << " return ::mlir::failure();"; 2428 } else { 2429 body << " return verifyInvariantsImpl();"; 2430 } 2431 } 2432 2433 void OpEmitter::genCustomVerifier() { 2434 if (def.getValueAsBit("hasVerifier")) { 2435 auto *method = opClass.declareMethod("::mlir::LogicalResult", "verify"); 2436 ERROR_IF_PRUNED(method, "verify", op); 2437 } 2438 2439 if (def.getValueAsBit("hasRegionVerifier")) { 2440 auto *method = 2441 opClass.declareMethod("::mlir::LogicalResult", "verifyRegions"); 2442 ERROR_IF_PRUNED(method, "verifyRegions", op); 2443 } 2444 } 2445 2446 void OpEmitter::genOperandResultVerifier(MethodBody &body, 2447 Operator::const_value_range values, 2448 StringRef valueKind) { 2449 // Check that an optional value is at most 1 element. 2450 // 2451 // {0}: Value index. 2452 // {1}: "operand" or "result" 2453 const char *const verifyOptional = R"( 2454 if (valueGroup{0}.size() > 1) { 2455 return emitOpError("{1} group starting at #") << index 2456 << " requires 0 or 1 element, but found " << valueGroup{0}.size(); 2457 } 2458 )"; 2459 // Check the types of a range of values. 2460 // 2461 // {0}: Value index. 2462 // {1}: Type constraint function. 
2463 // {2}: "operand" or "result" 2464 const char *const verifyValues = R"( 2465 for (auto v : valueGroup{0}) { 2466 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++))) 2467 return ::mlir::failure(); 2468 } 2469 )"; 2470 2471 const auto canSkip = [](const NamedTypeConstraint &value) { 2472 return !value.hasPredicate() && !value.isOptional() && 2473 !value.isVariadicOfVariadic(); 2474 }; 2475 if (values.empty() || llvm::all_of(values, canSkip)) 2476 return; 2477 2478 FmtContext fctx; 2479 2480 body << " {\n unsigned index = 0; (void)index;\n"; 2481 2482 for (const auto &staticValue : llvm::enumerate(values)) { 2483 const NamedTypeConstraint &value = staticValue.value(); 2484 2485 bool hasPredicate = value.hasPredicate(); 2486 bool isOptional = value.isOptional(); 2487 bool isVariadicOfVariadic = value.isVariadicOfVariadic(); 2488 if (!hasPredicate && !isOptional && !isVariadicOfVariadic) 2489 continue; 2490 body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n", 2491 // Capitalize the first letter to match the function name 2492 valueKind.substr(0, 1).upper(), valueKind.substr(1), 2493 staticValue.index()); 2494 2495 // If the constraint is optional check that the value group has at most 1 2496 // value. 2497 if (isOptional) { 2498 body << formatv(verifyOptional, staticValue.index(), valueKind); 2499 } else if (isVariadicOfVariadic) { 2500 body << formatv( 2501 " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr(" 2502 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n" 2503 " return ::mlir::failure();\n", 2504 value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name, 2505 staticValue.index()); 2506 } 2507 2508 // Otherwise, if there is no predicate there is nothing left to do. 2509 if (!hasPredicate) 2510 continue; 2511 // Emit a loop to check all the dynamic values in the pack. 
2512 StringRef constraintFn = 2513 staticVerifierEmitter.getTypeConstraintFn(value.constraint); 2514 body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind); 2515 } 2516 2517 body << " }\n"; 2518 } 2519 2520 void OpEmitter::genRegionVerifier(MethodBody &body) { 2521 /// Code to verify a region. 2522 /// 2523 /// {0}: Getter for the regions. 2524 /// {1}: The region constraint. 2525 /// {2}: The region's name. 2526 /// {3}: The region description. 2527 const char *const verifyRegion = R"( 2528 for (auto ®ion : {0}) 2529 if (::mlir::failed({1}(*this, region, "{2}", index++))) 2530 return ::mlir::failure(); 2531 )"; 2532 /// Get a single region. 2533 /// 2534 /// {0}: The region's index. 2535 const char *const getSingleRegion = 2536 "::llvm::makeMutableArrayRef((*this)->getRegion({0}))"; 2537 2538 // If we have no regions, there is nothing more to do. 2539 const auto canSkip = [](const NamedRegion ®ion) { 2540 return region.constraint.getPredicate().isNull(); 2541 }; 2542 auto regions = op.getRegions(); 2543 if (regions.empty() && llvm::all_of(regions, canSkip)) 2544 return; 2545 2546 body << " {\n unsigned index = 0; (void)index;\n"; 2547 for (const auto &it : llvm::enumerate(regions)) { 2548 const auto ®ion = it.value(); 2549 if (canSkip(region)) 2550 continue; 2551 2552 auto getRegion = region.isVariadic() 2553 ? formatv("{0}()", op.getGetterName(region.name)).str() 2554 : formatv(getSingleRegion, it.index()).str(); 2555 auto constraintFn = 2556 staticVerifierEmitter.getRegionConstraintFn(region.constraint); 2557 body << formatv(verifyRegion, getRegion, constraintFn, region.name); 2558 } 2559 body << " }\n"; 2560 } 2561 2562 void OpEmitter::genSuccessorVerifier(MethodBody &body) { 2563 const char *const verifySuccessor = R"( 2564 for (auto *successor : {0}) 2565 if (::mlir::failed({1}(*this, successor, "{2}", index++))) 2566 return ::mlir::failure(); 2567 )"; 2568 /// Get a single successor. 2569 /// 2570 /// {0}: The successor's name. 
2571 const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())"; 2572 2573 // If we have no successors, there is nothing more to do. 2574 const auto canSkip = [](const NamedSuccessor &successor) { 2575 return successor.constraint.getPredicate().isNull(); 2576 }; 2577 auto successors = op.getSuccessors(); 2578 if (successors.empty() && llvm::all_of(successors, canSkip)) 2579 return; 2580 2581 body << " {\n unsigned index = 0; (void)index;\n"; 2582 2583 for (auto &it : llvm::enumerate(successors)) { 2584 const auto &successor = it.value(); 2585 if (canSkip(successor)) 2586 continue; 2587 2588 auto getSuccessor = 2589 formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor, 2590 successor.name, it.index()) 2591 .str(); 2592 auto constraintFn = 2593 staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint); 2594 body << formatv(verifySuccessor, getSuccessor, constraintFn, 2595 successor.name); 2596 } 2597 body << " }\n"; 2598 } 2599 2600 /// Add a size count trait to the given operation class. 2601 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind, 2602 int numTotal, int numVariadic) { 2603 if (numVariadic != 0) { 2604 if (numTotal == numVariadic) 2605 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s"); 2606 else 2607 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" + 2608 Twine(numTotal - numVariadic) + ">::Impl"); 2609 return; 2610 } 2611 switch (numTotal) { 2612 case 0: 2613 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind); 2614 break; 2615 case 1: 2616 opClass.addTrait("::mlir::OpTrait::One" + traitKind); 2617 break; 2618 default: 2619 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) + 2620 ">::Impl"); 2621 break; 2622 } 2623 } 2624 2625 void OpEmitter::genTraits() { 2626 // Add region size trait. 
2627 unsigned numRegions = op.getNumRegions(); 2628 unsigned numVariadicRegions = op.getNumVariadicRegions(); 2629 addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions); 2630 2631 // Add result size traits. 2632 int numResults = op.getNumResults(); 2633 int numVariadicResults = op.getNumVariableLengthResults(); 2634 addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); 2635 2636 // For single result ops with a known specific type, generate a OneTypedResult 2637 // trait. 2638 if (numResults == 1 && numVariadicResults == 0) { 2639 auto cppName = op.getResults().begin()->constraint.getCPPClassName(); 2640 opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl"); 2641 } 2642 2643 // Add successor size trait. 2644 unsigned numSuccessors = op.getNumSuccessors(); 2645 unsigned numVariadicSuccessors = op.getNumVariadicSuccessors(); 2646 addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors); 2647 2648 // Add variadic size trait and normal op traits. 2649 int numOperands = op.getNumOperands(); 2650 int numVariadicOperands = op.getNumVariableLengthOperands(); 2651 2652 // Add operand size trait. 2653 if (numVariadicOperands != 0) { 2654 if (numOperands == numVariadicOperands) 2655 opClass.addTrait("::mlir::OpTrait::VariadicOperands"); 2656 else 2657 opClass.addTrait("::mlir::OpTrait::AtLeastNOperands<" + 2658 Twine(numOperands - numVariadicOperands) + ">::Impl"); 2659 } else { 2660 switch (numOperands) { 2661 case 0: 2662 opClass.addTrait("::mlir::OpTrait::ZeroOperands"); 2663 break; 2664 case 1: 2665 opClass.addTrait("::mlir::OpTrait::OneOperand"); 2666 break; 2667 default: 2668 opClass.addTrait("::mlir::OpTrait::NOperands<" + Twine(numOperands) + 2669 ">::Impl"); 2670 break; 2671 } 2672 } 2673 2674 // The op traits defined internal are ensured that they can be verified 2675 // earlier. 
2676 for (const auto &trait : op.getTraits()) { 2677 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) { 2678 if (opTrait->isStructuralOpTrait()) 2679 opClass.addTrait(opTrait->getFullyQualifiedTraitName()); 2680 } 2681 } 2682 2683 // OpInvariants wrapps the verifyInvariants which needs to be run before 2684 // native/interface traits and after all the traits with `StructuralOpTrait`. 2685 opClass.addTrait("::mlir::OpTrait::OpInvariants"); 2686 2687 // Add the native and interface traits. 2688 for (const auto &trait : op.getTraits()) { 2689 if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) { 2690 if (!opTrait->isStructuralOpTrait()) 2691 opClass.addTrait(opTrait->getFullyQualifiedTraitName()); 2692 } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) { 2693 opClass.addTrait(opTrait->getFullyQualifiedTraitName()); 2694 } 2695 } 2696 } 2697 2698 void OpEmitter::genOpNameGetter() { 2699 auto *method = opClass.addStaticMethod<Method::Constexpr>( 2700 "::llvm::StringLiteral", "getOperationName"); 2701 ERROR_IF_PRUNED(method, "getOperationName", op); 2702 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName() 2703 << "\");"; 2704 } 2705 2706 void OpEmitter::genOpAsmInterface() { 2707 // If the user only has one results or specifically added the Asm trait, 2708 // then don't generate it for them. We specifically only handle multi result 2709 // operations, because the name of a single result in the common case is not 2710 // interesting(generally 'result'/'output'/etc.). 2711 // TODO: We could also add a flag to allow operations to opt in to this 2712 // generation, even if they only have a single operation. 
2713 int numResults = op.getNumResults(); 2714 if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait")) 2715 return; 2716 2717 SmallVector<StringRef, 4> resultNames(numResults); 2718 for (int i = 0; i != numResults; ++i) 2719 resultNames[i] = op.getResultName(i); 2720 2721 // Don't add the trait if none of the results have a valid name. 2722 if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); })) 2723 return; 2724 opClass.addTrait("::mlir::OpAsmOpInterface::Trait"); 2725 2726 // Generate the right accessor for the number of results. 2727 auto *method = opClass.addMethod( 2728 "void", "getAsmResultNames", 2729 MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn")); 2730 ERROR_IF_PRUNED(method, "getAsmResultNames", op); 2731 auto &body = method->body(); 2732 for (int i = 0; i != numResults; ++i) { 2733 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" 2734 << " if (!llvm::empty(resultGroup" << i << "))\n" 2735 << " setNameFn(*resultGroup" << i << ".begin(), \"" 2736 << resultNames[i] << "\");\n"; 2737 } 2738 } 2739 2740 //===----------------------------------------------------------------------===// 2741 // OpOperandAdaptor emitter 2742 //===----------------------------------------------------------------------===// 2743 2744 namespace { 2745 // Helper class to emit Op operand adaptors to an output stream. Operand 2746 // adaptors are wrappers around ArrayRef<Value> that provide named operand 2747 // getters identical to those defined in the Op. 
2748 class OpOperandAdaptorEmitter { 2749 public: 2750 static void 2751 emitDecl(const Operator &op, 2752 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 2753 raw_ostream &os); 2754 static void 2755 emitDef(const Operator &op, 2756 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 2757 raw_ostream &os); 2758 2759 private: 2760 explicit OpOperandAdaptorEmitter( 2761 const Operator &op, 2762 const StaticVerifierFunctionEmitter &staticVerifierEmitter); 2763 2764 // Add verification function. This generates a verify method for the adaptor 2765 // which verifies all the op-independent attribute constraints. 2766 void addVerification(); 2767 2768 // The operation for which to emit an adaptor. 2769 const Operator &op; 2770 2771 // The generated adaptor class. 2772 Class adaptor; 2773 2774 // The emitter containing all of the locally emitted verification functions. 2775 const StaticVerifierFunctionEmitter &staticVerifierEmitter; 2776 2777 // Helper for emitting adaptor code. 2778 OpOrAdaptorHelper emitHelper; 2779 }; 2780 } // namespace 2781 2782 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( 2783 const Operator &op, 2784 const StaticVerifierFunctionEmitter &staticVerifierEmitter) 2785 : op(op), adaptor(op.getAdaptorName()), 2786 staticVerifierEmitter(staticVerifierEmitter), 2787 emitHelper(op, /*emitForOp=*/false) { 2788 adaptor.addField("::mlir::ValueRange", "odsOperands"); 2789 adaptor.addField("::mlir::DictionaryAttr", "odsAttrs"); 2790 adaptor.addField("::mlir::RegionRange", "odsRegions"); 2791 adaptor.addField("::llvm::Optional<::mlir::OperationName>", "odsOpName"); 2792 2793 const auto *attrSizedOperands = 2794 op.getTrait("::m::OpTrait::AttrSizedOperandSegments"); 2795 { 2796 SmallVector<MethodParameter> paramList; 2797 paramList.emplace_back("::mlir::ValueRange", "values"); 2798 paramList.emplace_back("::mlir::DictionaryAttr", "attrs", 2799 attrSizedOperands ? 
"" : "nullptr"); 2800 paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); 2801 auto *constructor = adaptor.addConstructor(std::move(paramList)); 2802 2803 constructor->addMemberInitializer("odsOperands", "values"); 2804 constructor->addMemberInitializer("odsAttrs", "attrs"); 2805 constructor->addMemberInitializer("odsRegions", "regions"); 2806 2807 MethodBody &body = constructor->body(); 2808 body.indent() << "if (odsAttrs)\n"; 2809 body.indent() << formatv( 2810 "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n", 2811 op.getOperationName()); 2812 } 2813 2814 { 2815 auto *constructor = 2816 adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op")); 2817 constructor->addMemberInitializer("odsOperands", "op->getOperands()"); 2818 constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); 2819 constructor->addMemberInitializer("odsRegions", "op->getRegions()"); 2820 constructor->addMemberInitializer("odsOpName", "op->getName()"); 2821 } 2822 2823 { 2824 auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands"); 2825 ERROR_IF_PRUNED(m, "getOperands", op); 2826 m->body() << " return odsOperands;"; 2827 } 2828 std::string sizeAttrInit; 2829 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { 2830 sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, 2831 emitHelper.getAttr(operandSegmentAttrName)); 2832 } 2833 generateNamedOperandGetters(op, adaptor, 2834 /*isAdaptor=*/true, sizeAttrInit, 2835 /*rangeType=*/"::mlir::ValueRange", 2836 /*rangeBeginCall=*/"odsOperands.begin()", 2837 /*rangeSizeCall=*/"odsOperands.size()", 2838 /*getOperandCallPattern=*/"odsOperands[{0}]"); 2839 2840 FmtContext fctx; 2841 fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); 2842 2843 // Generate named accessor with Attribute return type. 
2844 auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName, 2845 Attribute attr) { 2846 auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr"); 2847 ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op); 2848 auto &body = method->body().indent(); 2849 body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n" 2850 << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name), 2851 attr.hasDefaultValue() || attr.isOptional() 2852 ? "dyn_cast_or_null" 2853 : "cast", 2854 attr.getStorageType()); 2855 2856 if (attr.hasDefaultValue()) { 2857 // Use the default value if attribute is not set. 2858 // TODO: this is inefficient, we are recreating the attribute for every 2859 // call. This should be set instead. 2860 std::string defaultValue = std::string( 2861 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); 2862 body << " if (!attr)\n attr = " << defaultValue << ";\n"; 2863 } 2864 body << " return attr;\n"; 2865 }; 2866 2867 { 2868 auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes"); 2869 ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); 2870 m->body() << " return odsAttrs;"; 2871 } 2872 for (auto &namedAttr : op.getAttributes()) { 2873 const auto &name = namedAttr.name; 2874 const auto &attr = namedAttr.attr; 2875 if (attr.isDerivedAttr()) 2876 continue; 2877 for (const auto &emitName : op.getGetterNames(name)) { 2878 emitAttrWithStorageType(name, emitName, attr); 2879 emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr); 2880 } 2881 } 2882 2883 unsigned numRegions = op.getNumRegions(); 2884 if (numRegions > 0) { 2885 auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions"); 2886 ERROR_IF_PRUNED(m, "Adaptor::getRegions", op); 2887 m->body() << " return odsRegions;"; 2888 } 2889 for (unsigned i = 0; i < numRegions; ++i) { 2890 const auto ®ion = op.getRegion(i); 2891 if (region.name.empty()) 2892 continue; 2893 2894 // Generate the accessors for a 
variadic region. 2895 for (StringRef name : op.getGetterNames(region.name)) { 2896 if (region.isVariadic()) { 2897 auto *m = adaptor.addMethod("::mlir::RegionRange", name); 2898 ERROR_IF_PRUNED(m, "Adaptor::" + name, op); 2899 m->body() << formatv(" return odsRegions.drop_front({0});", i); 2900 continue; 2901 } 2902 2903 auto *m = adaptor.addMethod("::mlir::Region &", name); 2904 ERROR_IF_PRUNED(m, "Adaptor::" + name, op); 2905 m->body() << formatv(" return *odsRegions[{0}];", i); 2906 } 2907 } 2908 2909 // Add verification function. 2910 addVerification(); 2911 adaptor.finalize(); 2912 } 2913 2914 void OpOperandAdaptorEmitter::addVerification() { 2915 auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify", 2916 MethodParameter("::mlir::Location", "loc")); 2917 ERROR_IF_PRUNED(method, "verify", op); 2918 auto &body = method->body(); 2919 2920 FmtContext verifyCtx; 2921 populateSubstitutions(emitHelper, verifyCtx); 2922 genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter); 2923 2924 body << " return ::mlir::success();"; 2925 } 2926 2927 void OpOperandAdaptorEmitter::emitDecl( 2928 const Operator &op, 2929 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 2930 raw_ostream &os) { 2931 OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os); 2932 } 2933 2934 void OpOperandAdaptorEmitter::emitDef( 2935 const Operator &op, 2936 const StaticVerifierFunctionEmitter &staticVerifierEmitter, 2937 raw_ostream &os) { 2938 OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os); 2939 } 2940 2941 // Emits the opcode enum and op classes. 2942 static void emitOpClasses(const RecordKeeper &recordKeeper, 2943 const std::vector<Record *> &defs, raw_ostream &os, 2944 bool emitDecl) { 2945 // First emit forward declaration for each class, this allows them to refer 2946 // to each others in traits for example. 
2947 if (emitDecl) { 2948 os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n"; 2949 os << "#undef GET_OP_FWD_DEFINES\n"; 2950 for (auto *def : defs) { 2951 Operator op(*def); 2952 NamespaceEmitter emitter(os, op.getCppNamespace()); 2953 os << "class " << op.getCppClassName() << ";\n"; 2954 } 2955 os << "#endif\n\n"; 2956 } 2957 2958 IfDefScope scope("GET_OP_CLASSES", os); 2959 if (defs.empty()) 2960 return; 2961 2962 // Generate all of the locally instantiated methods first. 2963 StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper); 2964 os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); 2965 staticVerifierEmitter.emitOpConstraints(defs, emitDecl); 2966 2967 for (auto *def : defs) { 2968 Operator op(*def); 2969 if (emitDecl) { 2970 { 2971 NamespaceEmitter emitter(os, op.getCppNamespace()); 2972 os << formatv(opCommentHeader, op.getQualCppClassName(), 2973 "declarations"); 2974 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os); 2975 OpEmitter::emitDecl(op, os, staticVerifierEmitter); 2976 } 2977 // Emit the TypeID explicit specialization to have a single definition. 2978 if (!op.getCppNamespace().empty()) 2979 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() 2980 << "::" << op.getCppClassName() << ")\n\n"; 2981 } else { 2982 { 2983 NamespaceEmitter emitter(os, op.getCppNamespace()); 2984 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); 2985 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os); 2986 OpEmitter::emitDef(op, os, staticVerifierEmitter); 2987 } 2988 // Emit the TypeID explicit specialization to have a single definition. 2989 if (!op.getCppNamespace().empty()) 2990 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace() 2991 << "::" << op.getCppClassName() << ")\n\n"; 2992 } 2993 } 2994 } 2995 2996 // Emits a comma-separated list of the ops. 
2997 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) { 2998 IfDefScope scope("GET_OP_LIST", os); 2999 3000 interleave( 3001 // TODO: We are constructing the Operator wrapper instance just for 3002 // getting it's qualified class name here. Reduce the overhead by having a 3003 // lightweight version of Operator class just for that purpose. 3004 defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, 3005 [&os]() { os << ",\n"; }); 3006 } 3007 3008 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { 3009 emitSourceFileHeader("Op Declarations", os); 3010 3011 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper); 3012 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true); 3013 3014 return false; 3015 } 3016 3017 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { 3018 emitSourceFileHeader("Op Definitions", os); 3019 3020 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper); 3021 emitOpList(defs, os); 3022 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false); 3023 3024 return false; 3025 } 3026 3027 static mlir::GenRegistration 3028 genOpDecls("gen-op-decls", "Generate op declarations", 3029 [](const RecordKeeper &records, raw_ostream &os) { 3030 return emitOpDecls(records, os); 3031 }); 3032 3033 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", 3034 [](const RecordKeeper &records, 3035 raw_ostream &os) { 3036 return emitOpDefs(records, os); 3037 }); 3038