1 //===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpDefinitionsGen uses the description of operations to generate C++
10 // definitions for ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "OpClass.h"
15 #include "OpFormatGen.h"
16 #include "OpGenHelpers.h"
17 #include "mlir/TableGen/Class.h"
18 #include "mlir/TableGen/CodeGenHelpers.h"
19 #include "mlir/TableGen/Format.h"
20 #include "mlir/TableGen/GenInfo.h"
21 #include "mlir/TableGen/Interfaces.h"
22 #include "mlir/TableGen/Operator.h"
23 #include "mlir/TableGen/SideEffects.h"
24 #include "mlir/TableGen/Trait.h"
25 #include "llvm/ADT/MapVector.h"
26 #include "llvm/ADT/Sequence.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringSet.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/Signals.h"
32 #include "llvm/TableGen/Error.h"
33 #include "llvm/TableGen/Record.h"
34 #include "llvm/TableGen/TableGenBackend.h"
35
36 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
37
38 using namespace llvm;
39 using namespace mlir;
40 using namespace mlir::tblgen;
41
// Prefix for generated local variables that hold attribute values; keeps them
// from shadowing the attribute accessors of the same name.
static constexpr const char *tblgenNamePrefix = "tblgen_";
// Base name used for unnamed operands (suffixed with `_<index>`).
static constexpr const char *generatedArgName = "odsArg";
// Name of the builder parameter in generated build() methods.
static constexpr const char *odsBuilder = "odsBuilder";
// Name of the OperationState parameter in generated build() methods.
static constexpr const char *builderOpState = "odsState";
46
/// Implicit attribute holding the size of each operand segment (registered by
/// the AttrSizedOperandSegments trait).
static constexpr const char *operandSegmentAttrName = "operand_segment_sizes";
/// Implicit attribute holding the size of each result segment (registered by
/// the AttrSizedResultSegments trait).
static constexpr const char *resultSegmentAttrName = "result_segment_sizes";
51
/// formatv template for looking up an attribute on an op or adaptor. Because
/// named attributes are kept sorted, the lookup is restricted to the subrange
/// that excludes the required attributes known to sort before/after the target.
///
/// Placeholders:
///   {0}: code producing the attribute's name or identifier;
///   {1}: number of entries to skip at the front of the sorted range;
///   {2}: number of entries to skip at the back of the sorted range;
///   {3}: code producing the array of named attributes;
///   {4}: "Named" to obtain the NamedAttribute, or "" for the bare value.
static constexpr const char *subrangeGetAttr =
    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - {2}, {0})";
63
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic is not for
/// general use; it assumes all variadic operands/results must have the same
/// number of values.
///
/// The generated snippet reads a variable named `index` (the static
/// operand/result index), which must be in scope where the snippet is pasted.
/// It returns a brace-initialized `{start, size}` pair.
///
/// Placeholders use formatv syntax; literal braces are written as `{{`.
/// {0}: The list of whether each declared operand/result is variadic.
/// {1}: The total number of non-variadic operands/results.
/// {2}: The total number of variadic operands/results.
/// {3}: The total number of actual values.
/// {4}: "operand" or "result".
static const char *const sameVariadicSizeValueRangeCalcCode = R"(
bool isVariadic[] = {{{0}};
int prevVariadicCount = 0;
for (unsigned i = 0; i < index; ++i)
if (isVariadic[i]) ++prevVariadicCount;

// Calculate how many dynamic values a static variadic {4} corresponds to.
// This assumes all static variadic {4}s have the same dynamic value count.
int variadicSize = ({3} - {1}) / {2};
// `index` passed in as the parameter is the static index which counts each
// {4} (variadic or not) as size 1. So here for each previous static variadic
// {4}, we need to offset by (variadicSize - 1) to get where the dynamic
// value pack for this static {4} starts.
int start = index + (variadicSize - 1) * prevVariadicCount;
int size = isVariadic[index] ? variadicSize : 1;
return {{start, size};
)";
91
/// The logic to calculate the actual value range for a declared operand/result
/// of an op with variadic operands/results. Note that this logic assumes
/// the op has an attribute specifying the size of each operand/result segment
/// (variadic or not).
///
/// The snippet expects `sizeAttr` (a DenseIntElementsAttr) and `index` to be in
/// scope. A splat attribute means every segment has the same size. Otherwise
/// the start offset is the sum of the sizes of all preceding segments.
///
/// The snippet has no placeholders and uses single (unescaped) braces, so it is
/// presumably emitted verbatim rather than expanded through formatv. Verify at
/// the use site.
static const char *const attrSizedSegmentValueRangeCalcCode = R"(
const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
if (sizeAttr.isSplat())
return {*sizeAttrValueIt * index, *sizeAttrValueIt};

unsigned start = 0;
for (unsigned i = 0; i < index; ++i)
start += sizeAttrValueIt[i];
return {start, sizeAttrValueIt[index]};
)";
/// The code snippet to initialize the sizes for the value range calculation.
/// This is the adaptor variant: it asserts that the adaptor's `odsAttrs` is
/// available before reading the segment size attribute.
///
/// {0}: The code to get the attribute.
static const char *const adapterSegmentSizeAttrInitCode = R"(
assert(odsAttrs && "missing segment size attribute for op");
auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
)";
/// The code snippet to initialize the sizes for the value range calculation.
/// This is the op variant. Unlike the adaptor variant, it emits no `odsAttrs`
/// presence assertion.
///
/// {0}: The code to get the attribute.
static const char *const opSegmentSizeAttrInitCode = R"(
auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
)";
119
/// The logic to calculate the actual value range for a declared operand
/// of an op with variadic of variadic operands within the OpAdaptor.
///
/// The snippet splits the flat operand range of operand `{1}` into consecutive
/// groups, one per element of the segment attribute, and returns them as a
/// SmallVector<ValueRange>.
///
/// {0}: The name of the segment attribute. It is emitted followed by `()`, so
///      it must name a callable accessor.
/// {1}: The index of the main operand.
static const char *const variadicOfVariadicAdaptorCalcCode = R"(
auto tblgenTmpOperands = getODSOperands({1});
auto sizeAttrValues = {0}().getValues<uint32_t>();
auto sizeAttrIt = sizeAttrValues.begin();

::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{
tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt));
tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt);
}
return tblgenTmpOperandGroups;
)";
137
/// The logic to build a range of either operand or result values.
///
/// `{1}` must evaluate to a pair-like value whose `.first` is the start offset
/// and `.second` the length. The snippet returns the brace-initialized
/// [begin + first, begin + first + second) iterator pair.
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
auto valueRange = {1};
return {{std::next({0}, valueRange.first),
std::next({0}, valueRange.first + valueRange.second)};
)";
147
/// A header for indicating code sections in the generated output: a banner
/// comment framed by `//===---===//` rule lines, followed by a blank line.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
//===----------------------------------------------------------------------===//

)";
158
159 //===----------------------------------------------------------------------===//
160 // Utility structs and functions
161 //===----------------------------------------------------------------------===//
162
// Replaces all non-overlapping occurrences of `match` in `str` with
// `substitute`, scanning left to right, and returns the resulting string.
//
// Text that was just substituted is never rescanned, so a `substitute` that
// itself contains `match` does not cause repeated expansion.
//
// An empty `match` leaves `str` unchanged. Without this guard the loop would
// never terminate, because an empty needle is "found" at every position.
static std::string replaceAllSubstrs(std::string str, const std::string &match,
                                     const std::string &substitute) {
  if (match.empty())
    return str;

  std::string::size_type scanLoc = 0;
  while ((scanLoc = str.find(match, scanLoc)) != std::string::npos) {
    str.replace(scanLoc, match.size(), substitute);
    // Resume after the inserted text so it is not matched again.
    scanLoc += substitute.size();
  }
  return str;
}
173
174 // Returns whether the record has a value of the given name that can be returned
175 // via getValueAsString.
hasStringAttribute(const Record & record,StringRef fieldName)176 static inline bool hasStringAttribute(const Record &record,
177 StringRef fieldName) {
178 auto *valueInit = record.getValueInit(fieldName);
179 return isa<StringInit>(valueInit);
180 }
181
getArgumentName(const Operator & op,int index)182 static std::string getArgumentName(const Operator &op, int index) {
183 const auto &operand = op.getOperand(index);
184 if (!operand.name.empty())
185 return std::string(operand.name);
186 return std::string(formatv("{0}_{1}", generatedArgName, index));
187 }
188
189 // Returns true if we can use unwrapped value for the given `attr` in builders.
canUseUnwrappedRawValue(const tblgen::Attribute & attr)190 static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) {
191 return attr.getReturnType() != attr.getStorageType() &&
192 // We need to wrap the raw value into an attribute in the builder impl
193 // so we need to make sure that the attribute specifies how to do that.
194 !attr.getConstBuilderTemplate().empty();
195 }
196
197 namespace {
198 /// Metadata on a registered attribute. Given that attributes are stored in
199 /// sorted order on operations, we can use information from ODS to deduce the
200 /// number of required attributes less and and greater than each attribute,
201 /// allowing us to search only a subrange of the attributes in ODS-generated
202 /// getters.
203 struct AttributeMetadata {
204 /// The attribute name.
205 StringRef attrName;
206 /// Whether the attribute is required.
207 bool isRequired;
208 /// The ODS attribute constraint. Not present for implicit attributes.
209 Optional<Attribute> constraint;
210 /// The number of required attributes less than this attribute.
211 unsigned lowerBound = 0;
212 /// The number of required attributes greater than this attribute.
213 unsigned upperBound = 0;
214 };
215
216 /// Helper class to select between OpAdaptor and Op code templates.
217 class OpOrAdaptorHelper {
218 public:
OpOrAdaptorHelper(const Operator & op,bool emitForOp)219 OpOrAdaptorHelper(const Operator &op, bool emitForOp)
220 : op(op), emitForOp(emitForOp) {
221 computeAttrMetadata();
222 }
223
224 /// Object that wraps a functor in a stream operator for interop with
225 /// llvm::formatv.
226 class Formatter {
227 public:
228 template <typename Functor>
Formatter(Functor && func)229 Formatter(Functor &&func) : func(std::forward<Functor>(func)) {}
230
str() const231 std::string str() const {
232 std::string result;
233 llvm::raw_string_ostream os(result);
234 os << *this;
235 return os.str();
236 }
237
238 private:
239 std::function<raw_ostream &(raw_ostream &)> func;
240
operator <<(raw_ostream & os,const Formatter & fmt)241 friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) {
242 return fmt.func(os);
243 }
244 };
245
246 // Generate code for getting an attribute.
getAttr(StringRef attrName,bool isNamed=false) const247 Formatter getAttr(StringRef attrName, bool isNamed = false) const {
248 assert(attrMetadata.count(attrName) && "expected attribute metadata");
249 return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
250 const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
251 return os << formatv(subrangeGetAttr, getAttrName(attrName),
252 attr.lowerBound, attr.upperBound, getAttrRange(),
253 isNamed ? "Named" : "");
254 };
255 }
256
257 // Generate code for getting the name of an attribute.
getAttrName(StringRef attrName) const258 Formatter getAttrName(StringRef attrName) const {
259 return [this, attrName](raw_ostream &os) -> raw_ostream & {
260 if (emitForOp)
261 return os << op.getGetterName(attrName) << "AttrName()";
262 return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(),
263 op.getGetterName(attrName));
264 };
265 }
266
267 // Get the code snippet for getting the named attribute range.
getAttrRange() const268 StringRef getAttrRange() const {
269 return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
270 }
271
272 // Get the prefix code for emitting an error.
emitErrorPrefix() const273 Formatter emitErrorPrefix() const {
274 return [this](raw_ostream &os) -> raw_ostream & {
275 if (emitForOp)
276 return os << "emitOpError(";
277 return os << formatv("emitError(loc, \"'{0}' op \"",
278 op.getOperationName());
279 };
280 }
281
282 // Get the call to get an operand or segment of operands.
getOperand(unsigned index) const283 Formatter getOperand(unsigned index) const {
284 return [this, index](raw_ostream &os) -> raw_ostream & {
285 return os << formatv(op.getOperand(index).isVariadic()
286 ? "this->getODSOperands({0})"
287 : "(*this->getODSOperands({0}).begin())",
288 index);
289 };
290 }
291
292 // Get the call to get a result of segment of results.
getResult(unsigned index) const293 Formatter getResult(unsigned index) const {
294 return [this, index](raw_ostream &os) -> raw_ostream & {
295 if (!emitForOp)
296 return os << "<no results should be generated>";
297 return os << formatv(op.getResult(index).isVariadic()
298 ? "this->getODSResults({0})"
299 : "(*this->getODSResults({0}).begin())",
300 index);
301 };
302 }
303
304 // Return whether an op instance is available.
isEmittingForOp() const305 bool isEmittingForOp() const { return emitForOp; }
306
307 // Return the ODS operation wrapper.
getOp() const308 const Operator &getOp() const { return op; }
309
310 // Get the attribute metadata sorted by name.
getAttrMetadata() const311 const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
312 return attrMetadata;
313 }
314
315 private:
316 // Compute the attribute metadata.
317 void computeAttrMetadata();
318
319 // The operation ODS wrapper.
320 const Operator &op;
321 // True if code is being generate for an op. False for an adaptor.
322 const bool emitForOp;
323
324 // The attribute metadata, mapped by name.
325 llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
326 // The number of required attributes.
327 unsigned numRequired;
328 };
329
330 } // namespace
331
computeAttrMetadata()332 void OpOrAdaptorHelper::computeAttrMetadata() {
333 // Enumerate the attribute names of this op, ensuring the attribute names are
334 // unique in case implicit attributes are explicitly registered.
335 for (const NamedAttribute &namedAttr : op.getAttributes()) {
336 Attribute attr = namedAttr.attr;
337 bool isOptional =
338 attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
339 attrMetadata.insert(
340 {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}});
341 }
342 // Include key attributes from several traits as implicitly registered.
343 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
344 attrMetadata.insert(
345 {operandSegmentAttrName,
346 AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true,
347 /*attr=*/llvm::None}});
348 }
349 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
350 attrMetadata.insert(
351 {resultSegmentAttrName,
352 AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
353 /*attr=*/llvm::None}});
354 }
355
356 // Store the metadata in sorted order.
357 SmallVector<AttributeMetadata> sortedAttrMetadata =
358 llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector()));
359 llvm::sort(sortedAttrMetadata,
360 [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
361 return lhs.attrName < rhs.attrName;
362 });
363
364 // Compute the subrange bounds for each attribute.
365 numRequired = 0;
366 for (AttributeMetadata &attr : sortedAttrMetadata) {
367 attr.lowerBound = numRequired;
368 numRequired += attr.isRequired;
369 };
370 for (AttributeMetadata &attr : sortedAttrMetadata)
371 attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
372
373 // Store the results back into the map.
374 for (const AttributeMetadata &attr : sortedAttrMetadata)
375 attrMetadata.insert({attr.attrName, attr});
376 }
377
378 //===----------------------------------------------------------------------===//
379 // Op emitter
380 //===----------------------------------------------------------------------===//
381
382 namespace {
383 // Helper class to emit a record into the given output stream.
384 class OpEmitter {
385 public:
386 static void
387 emitDecl(const Operator &op, raw_ostream &os,
388 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
389 static void
390 emitDef(const Operator &op, raw_ostream &os,
391 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
392
393 private:
394 OpEmitter(const Operator &op,
395 const StaticVerifierFunctionEmitter &staticVerifierEmitter);
396
397 void emitDecl(raw_ostream &os);
398 void emitDef(raw_ostream &os);
399
400 // Generate methods for accessing the attribute names of this operation.
401 void genAttrNameGetters();
402
403 // Generates the OpAsmOpInterface for this operation if possible.
404 void genOpAsmInterface();
405
406 // Generates the `getOperationName` method for this op.
407 void genOpNameGetter();
408
409 // Generates getters for the attributes.
410 void genAttrGetters();
411
412 // Generates setter for the attributes.
413 void genAttrSetters();
414
415 // Generates removers for optional attributes.
416 void genOptionalAttrRemovers();
417
418 // Generates getters for named operands.
419 void genNamedOperandGetters();
420
421 // Generates setters for named operands.
422 void genNamedOperandSetters();
423
424 // Generates getters for named results.
425 void genNamedResultGetters();
426
427 // Generates getters for named regions.
428 void genNamedRegionGetters();
429
430 // Generates getters for named successors.
431 void genNamedSuccessorGetters();
432
433 // Generates the method to populate default attributes.
434 void genPopulateDefaultAttributes();
435
436 // Generates builder methods for the operation.
437 void genBuilder();
438
439 // Generates the build() method that takes each operand/attribute
440 // as a stand-alone parameter.
441 void genSeparateArgParamBuilder();
442
443 // Generates the build() method that takes each operand/attribute as a
444 // stand-alone parameter. The generated build() method uses first operand's
445 // type as all results' types.
446 void genUseOperandAsResultTypeSeparateParamBuilder();
447
448 // Generates the build() method that takes all operands/attributes
449 // collectively as one parameter. The generated build() method uses first
450 // operand's type as all results' types.
451 void genUseOperandAsResultTypeCollectiveParamBuilder();
452
453 // Generates the build() method that takes aggregate operands/attributes
454 // parameters. This build() method uses inferred types as result types.
455 // Requires: The type needs to be inferable via InferTypeOpInterface.
456 void genInferredTypeCollectiveParamBuilder();
457
458 // Generates the build() method that takes each operand/attribute as a
459 // stand-alone parameter. The generated build() method uses first attribute's
460 // type as all result's types.
461 void genUseAttrAsResultTypeBuilder();
462
463 // Generates the build() method that takes all result types collectively as
464 // one parameter. Similarly for operands and attributes.
465 void genCollectiveParamBuilder();
466
467 // The kind of parameter to generate for result types in builders.
468 enum class TypeParamKind {
469 None, // No result type in parameter list.
470 Separate, // A separate parameter for each result type.
471 Collective, // An ArrayRef<Type> for all result types.
472 };
473
474 // The kind of parameter to generate for attributes in builders.
475 enum class AttrParamKind {
476 WrappedAttr, // A wrapped MLIR Attribute instance.
477 UnwrappedValue, // A raw value without MLIR Attribute wrapper.
478 };
479
480 // Builds the parameter list for build() method of this op. This method writes
481 // to `paramList` the comma-separated parameter list and updates
482 // `resultTypeNames` with the names for parameters for specifying result
483 // types. `inferredAttributes` is populated with any attributes that are
484 // elided from the build list. The given `typeParamKind` and `attrParamKind`
485 // controls how result types and attributes are placed in the parameter list.
486 void buildParamList(SmallVectorImpl<MethodParameter> ¶mList,
487 llvm::StringSet<> &inferredAttributes,
488 SmallVectorImpl<std::string> &resultTypeNames,
489 TypeParamKind typeParamKind,
490 AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
491
492 // Adds op arguments and regions into operation state for build() methods.
493 void
494 genCodeForAddingArgAndRegionForBuilder(MethodBody &body,
495 llvm::StringSet<> &inferredAttributes,
496 bool isRawValueAttr = false);
497
498 // Generates canonicalizer declaration for the operation.
499 void genCanonicalizerDecls();
500
501 // Generates the folder declaration for the operation.
502 void genFolderDecls();
503
504 // Generates the parser for the operation.
505 void genParser();
506
507 // Generates the printer for the operation.
508 void genPrinter();
509
510 // Generates verify method for the operation.
511 void genVerifier();
512
513 // Generates custom verify methods for the operation.
514 void genCustomVerifier();
515
516 // Generates verify statements for operands and results in the operation.
517 // The generated code will be attached to `body`.
518 void genOperandResultVerifier(MethodBody &body,
519 Operator::const_value_range values,
520 StringRef valueKind);
521
522 // Generates verify statements for regions in the operation.
523 // The generated code will be attached to `body`.
524 void genRegionVerifier(MethodBody &body);
525
526 // Generates verify statements for successors in the operation.
527 // The generated code will be attached to `body`.
528 void genSuccessorVerifier(MethodBody &body);
529
530 // Generates the traits used by the object.
531 void genTraits();
532
533 // Generate the OpInterface methods for all interfaces.
534 void genOpInterfaceMethods();
535
536 // Generate op interface methods for the given interface.
537 void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait);
538
539 // Generate op interface method for the given interface method. If
540 // 'declaration' is true, generates a declaration, else a definition.
541 Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
542 bool declaration = true);
543
544 // Generate the side effect interface methods.
545 void genSideEffectInterfaceMethods();
546
547 // Generate the type inference interface methods.
548 void genTypeInterfaceMethods();
549
550 private:
551 // The TableGen record for this op.
552 // TODO: OpEmitter should not have a Record directly,
553 // it should rather go through the Operator for better abstraction.
554 const Record &def;
555
556 // The wrapper operator class for querying information from this op.
557 const Operator &op;
558
559 // The C++ code builder for this op
560 OpClass opClass;
561
562 // The format context for verification code generation.
563 FmtContext verifyCtx;
564
565 // The emitter containing all of the locally emitted verification functions.
566 const StaticVerifierFunctionEmitter &staticVerifierEmitter;
567
568 // Helper for emitting op code.
569 OpOrAdaptorHelper emitHelper;
570 };
571
572 } // namespace
573
574 // Populate the format context `ctx` with substitutions of attributes, operands
575 // and results.
populateSubstitutions(const OpOrAdaptorHelper & emitHelper,FmtContext & ctx)576 static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper,
577 FmtContext &ctx) {
578 // Populate substitutions for attributes.
579 auto &op = emitHelper.getOp();
580 for (const auto &namedAttr : op.getAttributes())
581 ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str());
582
583 // Populate substitutions for named operands.
584 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
585 auto &value = op.getOperand(i);
586 if (!value.name.empty())
587 ctx.addSubst(value.name, emitHelper.getOperand(i).str());
588 }
589
590 // Populate substitutions for results.
591 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
592 auto &value = op.getResult(i);
593 if (!value.name.empty())
594 ctx.addSubst(value.name, emitHelper.getResult(i).str());
595 }
596 }
597
598 /// Generate verification on native traits requiring attributes.
genNativeTraitAttrVerifier(MethodBody & body,const OpOrAdaptorHelper & emitHelper)599 static void genNativeTraitAttrVerifier(MethodBody &body,
600 const OpOrAdaptorHelper &emitHelper) {
601 // Check that the variadic segment sizes attribute exists and contains the
602 // expected number of elements.
603 //
604 // {0}: Attribute name.
605 // {1}: Expected number of elements.
606 // {2}: "operand" or "result".
607 // {3}: Emit error prefix.
608 const char *const checkAttrSizedValueSegmentsCode = R"(
609 {
610 auto sizeAttr = tblgen_{0}.cast<::mlir::DenseIntElementsAttr>();
611 auto numElements =
612 sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
613 if (numElements != {1})
614 return {3}"'{0}' attribute for specifying {2} segments must have {1} "
615 "elements, but got ") << numElements;
616 }
617 )";
618
619 // Verify a few traits first so that we can use getODSOperands() and
620 // getODSResults() in the rest of the verifier.
621 auto &op = emitHelper.getOp();
622 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
623 body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
624 op.getNumOperands(), "operand",
625 emitHelper.emitErrorPrefix());
626 }
627 if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
628 body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
629 op.getNumResults(), "result", emitHelper.emitErrorPrefix());
630 }
631 }
632
// Generate attribute verification. If an op instance is not available, then
// attribute checks that require one will not be emitted.
//
// Attribute verification is performed as follows:
//
// 1. Verify that all required attributes are present in sorted order. This
// ensures that we can use subrange lookup even with potentially missing
// attributes.
// 2. Verify native trait attributes so that other attributes may call methods
// that depend on the validity of these attributes, e.g. segment size attributes
// and operand or result getters.
// 3. Verify the constraints on all present attributes.
//
// Parameters:
//   emitHelper - selects op vs. adaptor snippets and provides the name-sorted
//                attribute metadata (explicit and implicit attributes).
//   ctx        - format context used to expand inline attribute predicates.
//                Note: `ctx.withSelf(...)` is invoked on the caller's context,
//                not on a copy.
//   body       - method body that the generated C++ is appended to.
//   staticVerifierEmitter - provides uniqued attribute constraint functions;
//                only consulted when emitting for an op.
static void genAttributeVerifier(
    const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
  // With no explicit or implicit attributes there is nothing to emit, not even
  // the iterator scaffolding.
  if (emitHelper.getAttrMetadata().empty())
    return;

  // Verify the attribute if it is present. This assumes that default values
  // are valid. This code snippet pastes the condition inline.
  //
  // TODO: verify the default value is valid (perhaps in debug mode only).
  //
  // {0}: Attribute variable name.
  // {1}: Attribute condition code.
  // {2}: Emit error prefix.
  // {3}: Attribute name.
  // {4}: Attribute/constraint description.
  const char *const verifyAttrInline = R"(
if ({0} && !({1}))
return {2}"attribute '{3}' failed to satisfy constraint: {4}");
)";
  // Verify the attribute using a uniqued constraint. Can only be used within
  // the context of an op.
  //
  // {0}: Unique constraint name.
  // {1}: Attribute variable name.
  // {2}: Attribute name.
  const char *const verifyAttrUnique = R"(
if (::mlir::failed({0}(*this, {1}, "{2}")))
return ::mlir::failure();
)";

  // Traverse the array until the required attribute is found. Return an error
  // if the traversal reached the end.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The emit error prefix.
  // {2}: The name of the attribute.
  //
  // The emitted `while` loop is deliberately left open. The caller appends the
  // optional-attribute checks and then the closing `++namedAttrIt; }`.
  const char *const findRequiredAttr = R"(while (true) {{
if (namedAttrIt == namedAttrRange.end())
return {1}"requires attribute '{2}'");
if (namedAttrIt->getName() == {0}) {{
tblgen_{2} = namedAttrIt->getValue();
break;
})";

  // Emit a check to see if the iteration has encountered an optional attribute.
  //
  // {0}: Code to get the name of the attribute.
  // {1}: The name of the attribute.
  const char *const checkOptionalAttr = R"(
else if (namedAttrIt->getName() == {0}) {{
tblgen_{1} = namedAttrIt->getValue();
})";

  // Emit the start of the loop for checking trailing attributes.
  // This snippet is streamed directly (not through formatv), so its braces are
  // not doubled.
  const char *const checkTrailingAttrs = R"(while (true) {
if (namedAttrIt == namedAttrRange.end()) {
break;
})";

  // Return true if a verifier can be emitted for the attribute: it is not a
  // derived attribute, it has a predicate, its condition is not empty, and, for
  // adaptors, the condition does not reference the op.
  const auto canEmitVerifier = [&](Attribute attr) {
    if (attr.isDerivedAttr())
      return false;
    Pred pred = attr.getPredicate();
    if (pred.isNull())
      return false;
    std::string condition = pred.getCondition();
    return !condition.empty() && (!StringRef(condition).contains("$_op") ||
                                  emitHelper.isEmittingForOp());
  };

  // Emit the verifier for the attribute. On ops, a uniqued static constraint
  // function is preferred when one exists; otherwise the predicate is inlined.
  const auto emitVerifier = [&](Attribute attr, StringRef attrName,
                                StringRef varName) {
    std::string condition = attr.getPredicate().getCondition();

    Optional<StringRef> constraintFn;
    if (emitHelper.isEmittingForOp() &&
        (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
      body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
    } else {
      body << formatv(verifyAttrInline, varName,
                      tgfmt(condition, &ctx.withSelf(varName)),
                      emitHelper.emitErrorPrefix(), attrName,
                      escapeString(attr.getSummary()));
    }
  };

  // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
  const auto getVarName = [&](StringRef attrName) {
    return (tblgenNamePrefix + attrName).str();
  };

  body.indent() << formatv("auto namedAttrRange = {0};\n",
                           emitHelper.getAttrRange());
  body << "auto namedAttrIt = namedAttrRange.begin();\n";

  // Iterate over the attributes in sorted order. Keep track of the optional
  // attributes that may be encountered along the way.
  //
  // For each required attribute, emit one loop that advances `namedAttrIt`
  // until that attribute is found. The loop also captures any optional
  // attributes that sort between the previous required attribute and this one.
  // `break` leaves `namedAttrIt` on the matched attribute, so the next loop's
  // first iteration simply steps past it via `++namedAttrIt`.
  SmallVector<const AttributeMetadata *> optionalAttrs;
  for (const std::pair<StringRef, AttributeMetadata> &it :
       emitHelper.getAttrMetadata()) {
    const AttributeMetadata &metadata = it.second;
    if (!metadata.isRequired) {
      optionalAttrs.push_back(&metadata);
      continue;
    }

    body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
    for (const AttributeMetadata *optional : optionalAttrs) {
      body << formatv("::mlir::Attribute {0};\n",
                      getVarName(optional->attrName));
    }
    body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
                    emitHelper.emitErrorPrefix(), it.first);
    for (const AttributeMetadata *optional : optionalAttrs) {
      body << formatv(checkOptionalAttr,
                      emitHelper.getAttrName(optional->attrName),
                      optional->attrName);
    }
    // Close the loop opened by findRequiredAttr.
    body << "\n ++namedAttrIt;\n}\n";
    optionalAttrs.clear();
  }
  // Get trailing optional attributes, i.e. those sorting after the last
  // required attribute. Their loop ends at the end of the range without error.
  if (!optionalAttrs.empty()) {
    for (const AttributeMetadata *optional : optionalAttrs) {
      body << formatv("::mlir::Attribute {0};\n",
                      getVarName(optional->attrName));
    }
    body << checkTrailingAttrs;
    for (const AttributeMetadata *optional : optionalAttrs) {
      body << formatv(checkOptionalAttr,
                      emitHelper.getAttrName(optional->attrName),
                      optional->attrName);
    }
    body << "\n ++namedAttrIt;\n}\n";
  }
  body.unindent();

  // Emit the checks for segment attributes first so that the other constraints
  // can call operand and result getters.
  genNativeTraitAttrVerifier(body, emitHelper);

  // Finally, check the constraint of each explicitly declared attribute. The
  // generated checks test the `tblgen_` variable, so absent optional
  // attributes pass.
  for (const auto &namedAttr : emitHelper.getOp().getAttributes())
    if (canEmitVerifier(namedAttr.attr))
      emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
}
795
796 /// Op extra class definitions have a `$cppClass` substitution that is to be
797 /// replaced by the C++ class name.
formatExtraDefinitions(const Operator & op)798 static std::string formatExtraDefinitions(const Operator &op) {
799 FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName());
800 return tgfmt(op.getExtraClassDefinition(), &ctx).str();
801 }
802
/// Constructs the emitter for `op` and eagerly generates every member of the
/// op's C++ class into `opClass`. The class is created with the op's C++ class
/// name, its extra class declaration, and its extra class definitions (with
/// `$cppClass` already expanded by formatExtraDefinitions).
OpEmitter::OpEmitter(const Operator &op,
                     const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : def(op.getDef()), op(op),
      opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
              formatExtraDefinitions(op)),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/true) {
  // Bind the op substitution (via withOp) and `$_ctxt` in verifyCtx so that
  // code templates expanded with it refer to this operation and its context.
  // Presumably used when expanding verification code; confirm at use sites.
  verifyCtx.withOp("(*this->getOperation())");
  verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");

  genTraits();

  // Generate C++ code for various op methods. The order here determines the
  // methods in the generated file.
  genAttrNameGetters();
  genOpAsmInterface();
  genOpNameGetter();
  genNamedOperandGetters();
  genNamedOperandSetters();
  genNamedResultGetters();
  genNamedRegionGetters();
  genNamedSuccessorGetters();
  genAttrGetters();
  genAttrSetters();
  genOptionalAttrRemovers();
  genBuilder();
  genPopulateDefaultAttributes();
  genParser();
  genPrinter();
  genVerifier();
  genCustomVerifier();
  genCanonicalizerDecls();
  genFolderDecls();
  genTypeInterfaceMethods();
  genOpInterfaceMethods();
  generateOpFormat(op, opClass);
  genSideEffectInterfaceMethods();
}
emitDecl(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)841 void OpEmitter::emitDecl(
842 const Operator &op, raw_ostream &os,
843 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
844 OpEmitter(op, staticVerifierEmitter).emitDecl(os);
845 }
846
emitDef(const Operator & op,raw_ostream & os,const StaticVerifierFunctionEmitter & staticVerifierEmitter)847 void OpEmitter::emitDef(
848 const Operator &op, raw_ostream &os,
849 const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
850 OpEmitter(op, staticVerifierEmitter).emitDef(os);
851 }
852
emitDecl(raw_ostream & os)853 void OpEmitter::emitDecl(raw_ostream &os) {
854 opClass.finalize();
855 opClass.writeDeclTo(os);
856 }
857
emitDef(raw_ostream & os)858 void OpEmitter::emitDef(raw_ostream &os) {
859 opClass.finalize();
860 opClass.writeDefTo(os);
861 }
862
errorIfPruned(size_t line,Method * m,const Twine & methodName,const Operator & op)863 static void errorIfPruned(size_t line, Method *m, const Twine &methodName,
864 const Operator &op) {
865 if (m)
866 return;
867 PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" +
868 methodName + "` for " +
869 op.getOperationName() + " (from line " +
870 Twine(line) + ")");
871 }
872
873 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
874
/// Generates the accessors exposing the op's attribute names:
///  - static `getAttributeNames()`, listing every attribute name in
///    `emitHelper.getAttrMetadata()` order;
///  - private `getAttributeNameForIndex(index)` (member and static variants);
///  - for every getter name of every attribute, `<name>AttrName()` (member)
///    and `<name>AttrName(OperationName)` (static), each returning the name at
///    that attribute's index.
/// When the op has no attributes, only `getAttributeNames()` is generated.
void OpEmitter::genAttrNameGetters() {
  const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
      emitHelper.getAttrMetadata();

  // Emit the getAttributeNames method.
  {
    auto *method = opClass.addStaticInlineMethod(
        "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
    ERROR_IF_PRUNED(method, "getAttributeNames", op);
    auto &body = method->body();
    if (attributes.empty()) {
      body << " return {};";
      // Nothing else to do if there are no registered attributes. Exit early.
      // (This returns from genAttrNameGetters, skipping all methods below.)
      return;
    }
    body << " static ::llvm::StringRef attrNames[] = {";
    llvm::interleaveComma(llvm::make_first_range(attributes), body,
                          [&](StringRef attrName) {
                            body << "::llvm::StringRef(\"" << attrName << "\")";
                          });
    body << "};\n return ::llvm::makeArrayRef(attrNames);";
  }

  // Emit the getAttributeNameForIndex methods.
  {
    // Member variant: forwards to the static variant with this op's name.
    auto *method = opClass.addInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);
    method->body()
        << " return getAttributeNameForIndex((*this)->getName(), index);";
  }
  {
    // Static variant: asserts the index is in range and reads the name from
    // the registered operation info.
    auto *method = opClass.addStaticInlineMethod<Method::Private>(
        "::mlir::StringAttr", "getAttributeNameForIndex",
        MethodParameter("::mlir::OperationName", "name"),
        MethodParameter("unsigned", "index"));
    ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op);

    const char *const getAttrName = R"(
assert(index < {0} && "invalid attribute index");
return name.getRegisteredInfo()->getAttributeNames()[index];
)";
    method->body() << formatv(getAttrName, attributes.size());
  }

  // Generate the <attr>AttrName methods, that expose the attribute names to
  // users.
  const char *attrNameMethodBody = " return getAttributeNameForIndex({0});";
  for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) {
    for (StringRef name : op.getGetterNames(attrIt.value())) {
      std::string methodName = (name + "AttrName").str();

      // Generate the non-static variant.
      {
        auto *method =
            opClass.addInlineMethod("::mlir::StringAttr", methodName);
        ERROR_IF_PRUNED(method, methodName, op);
        method->body() << llvm::formatv(attrNameMethodBody, attrIt.index());
      }

      // Generate the static variant.
      {
        auto *method = opClass.addStaticInlineMethod(
            "::mlir::StringAttr", methodName,
            MethodParameter("::mlir::OperationName", "name"));
        ERROR_IF_PRUNED(method, methodName, op);
        // The index argument becomes "name, <index>" to select the static
        // getAttributeNameForIndex overload.
        method->body() << llvm::formatv(attrNameMethodBody,
                                        "name, " + Twine(attrIt.index()));
      }
    }
  }
}
948
949 // Emit the getter for an attribute with the return type specified.
950 // It is templated to be shared between the Op and the adaptor class.
951 template <typename OpClassOrAdaptor>
emitAttrGetterWithReturnType(FmtContext & fctx,OpClassOrAdaptor & opClass,const Operator & op,StringRef name,Attribute attr)952 static void emitAttrGetterWithReturnType(FmtContext &fctx,
953 OpClassOrAdaptor &opClass,
954 const Operator &op, StringRef name,
955 Attribute attr) {
956 auto *method = opClass.addMethod(attr.getReturnType(), name);
957 ERROR_IF_PRUNED(method, name, op);
958 auto &body = method->body();
959 body << " auto attr = " << name << "Attr();\n";
960 if (attr.hasDefaultValue()) {
961 // Returns the default value if not set.
962 // TODO: this is inefficient, we are recreating the attribute for every
963 // call. This should be set instead.
964 if (!attr.isConstBuildable()) {
965 PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() +
966 " must have a constBuilder");
967 }
968 std::string defaultValue = std::string(
969 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
970 body << " if (!attr)\n return "
971 << tgfmt(attr.getConvertFromStorageCall(),
972 &fctx.withSelf(defaultValue))
973 << ";\n";
974 }
975 body << " return "
976 << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr"))
977 << ";\n";
978 }
979
genAttrGetters()980 void OpEmitter::genAttrGetters() {
981 FmtContext fctx;
982 fctx.withBuilder("::mlir::Builder((*this)->getContext())");
983
984 // Emit the derived attribute body.
985 auto emitDerivedAttr = [&](StringRef name, Attribute attr) {
986 if (auto *method = opClass.addMethod(attr.getReturnType(), name))
987 method->body() << " " << attr.getDerivedCodeBody() << "\n";
988 };
989
990 // Generate named accessor with Attribute return type. This is a wrapper class
991 // that allows referring to the attributes via accessors instead of having to
992 // use the string interface for better compile time verification.
993 auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName,
994 Attribute attr) {
995 auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr");
996 if (!method)
997 return;
998 method->body() << formatv(
999 " return {0}.{1}<{2}>();", emitHelper.getAttr(attrName),
1000 attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
1001 : "cast",
1002 attr.getStorageType());
1003 };
1004
1005 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1006 for (StringRef name : op.getGetterNames(namedAttr.name)) {
1007 if (namedAttr.attr.isDerivedAttr()) {
1008 emitDerivedAttr(name, namedAttr.attr);
1009 } else {
1010 emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr);
1011 emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr);
1012 }
1013 }
1014 }
1015
1016 auto derivedAttrs = make_filter_range(op.getAttributes(),
1017 [](const NamedAttribute &namedAttr) {
1018 return namedAttr.attr.isDerivedAttr();
1019 });
1020 if (derivedAttrs.empty())
1021 return;
1022
1023 opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait");
1024 // Generate helper method to query whether a named attribute is a derived
1025 // attribute. This enables, for example, avoiding adding an attribute that
1026 // overlaps with a derived attribute.
1027 {
1028 auto *method =
1029 opClass.addStaticMethod("bool", "isDerivedAttribute",
1030 MethodParameter("::llvm::StringRef", "name"));
1031 ERROR_IF_PRUNED(method, "isDerivedAttribute", op);
1032 auto &body = method->body();
1033 for (auto namedAttr : derivedAttrs)
1034 body << " if (name == \"" << namedAttr.name << "\") return true;\n";
1035 body << " return false;";
1036 }
1037 // Generate method to materialize derived attributes as a DictionaryAttr.
1038 {
1039 auto *method = opClass.addMethod("::mlir::DictionaryAttr",
1040 "materializeDerivedAttributes");
1041 ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op);
1042 auto &body = method->body();
1043
1044 auto nonMaterializable =
1045 make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) {
1046 return namedAttr.attr.getConvertFromStorageCall().empty();
1047 });
1048 if (!nonMaterializable.empty()) {
1049 std::string attrs;
1050 llvm::raw_string_ostream os(attrs);
1051 interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) {
1052 os << op.getGetterName(attr.name);
1053 });
1054 PrintWarning(
1055 op.getLoc(),
1056 formatv(
1057 "op has non-materializable derived attributes '{0}', skipping",
1058 os.str()));
1059 body << formatv(" emitOpError(\"op has non-materializable derived "
1060 "attributes '{0}'\");\n",
1061 attrs);
1062 body << " return nullptr;";
1063 return;
1064 }
1065
1066 body << " ::mlir::MLIRContext* ctx = getContext();\n";
1067 body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n";
1068 body << " return ::mlir::DictionaryAttr::get(";
1069 body << " ctx, {\n";
1070 interleave(
1071 derivedAttrs, body,
1072 [&](const NamedAttribute &namedAttr) {
1073 auto tmpl = namedAttr.attr.getConvertFromStorageCall();
1074 std::string name = op.getGetterName(namedAttr.name);
1075 body << " {" << name << "AttrName(),\n"
1076 << tgfmt(tmpl, &fctx.withSelf(name + "()")
1077 .withBuilder("odsBuilder")
1078 .addSubst("_ctxt", "ctx"))
1079 << "}";
1080 },
1081 ",\n");
1082 body << "});";
1083 }
1084 }
1085
genAttrSetters()1086 void OpEmitter::genAttrSetters() {
1087 // Generate raw named setter type. This is a wrapper class that allows setting
1088 // to the attributes via setters instead of having to use the string interface
1089 // for better compile time verification.
1090 auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
1091 Attribute attr) {
1092 auto *method =
1093 opClass.addMethod("void", setterName + "Attr",
1094 MethodParameter(attr.getStorageType(), "attr"));
1095 if (method)
1096 method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);",
1097 getterName);
1098 };
1099
1100 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1101 if (namedAttr.attr.isDerivedAttr())
1102 continue;
1103 for (auto names : llvm::zip(op.getSetterNames(namedAttr.name),
1104 op.getGetterNames(namedAttr.name)))
1105 emitAttrWithStorageType(std::get<0>(names), std::get<1>(names),
1106 namedAttr.attr);
1107 }
1108 }
1109
genOptionalAttrRemovers()1110 void OpEmitter::genOptionalAttrRemovers() {
1111 // Generate methods for removing optional attributes, instead of having to
1112 // use the string interface. Enables better compile time verification.
1113 auto emitRemoveAttr = [&](StringRef name) {
1114 auto upperInitial = name.take_front().upper();
1115 auto suffix = name.drop_front();
1116 auto *method = opClass.addMethod("::mlir::Attribute",
1117 "remove" + upperInitial + suffix + "Attr");
1118 if (!method)
1119 return;
1120 method->body() << formatv(" return (*this)->removeAttr({0}AttrName());",
1121 op.getGetterName(name));
1122 };
1123
1124 for (const NamedAttribute &namedAttr : op.getAttributes())
1125 if (namedAttr.attr.isOptional())
1126 emitRemoveAttr(namedAttr.name);
1127 }
1128
1129 // Generates the code to compute the start and end index of an operand or result
1130 // range.
1131 template <typename RangeT>
1132 static void
generateValueRangeStartAndEnd(Class & opClass,StringRef methodName,int numVariadic,int numNonVariadic,StringRef rangeSizeCall,bool hasAttrSegmentSize,StringRef sizeAttrInit,RangeT && odsValues)1133 generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
1134 int numVariadic, int numNonVariadic,
1135 StringRef rangeSizeCall, bool hasAttrSegmentSize,
1136 StringRef sizeAttrInit, RangeT &&odsValues) {
1137 auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName,
1138 MethodParameter("unsigned", "index"));
1139 if (!method)
1140 return;
1141 auto &body = method->body();
1142 if (numVariadic == 0) {
1143 body << " return {index, 1};\n";
1144 } else if (hasAttrSegmentSize) {
1145 body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode;
1146 } else {
1147 // Because the op can have arbitrarily interleaved variadic and non-variadic
1148 // operands, we need to embed a list in the "sink" getter method for
1149 // calculation at run-time.
1150 SmallVector<StringRef, 4> isVariadic;
1151 isVariadic.reserve(llvm::size(odsValues));
1152 for (auto &it : odsValues)
1153 isVariadic.push_back(it.isVariableLength() ? "true" : "false");
1154 std::string isVariadicList = llvm::join(isVariadic, ", ");
1155 body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList,
1156 numNonVariadic, numVariadic, rangeSizeCall, "operand");
1157 }
1158 }
1159
// Generates the named operand getter methods for the given Operator `op` and
// puts them in `opClass`.
//
//  - `isAdaptor` selects the adaptor form of variadic-of-variadic getters
//    (returning `::llvm::SmallVector<::mlir::ValueRange>`) instead of
//    `::mlir::OperandRangeRange`.
//  - `sizeAttrInit` is emitted before the attribute-sized segment calculation
//    when the op has the AttrSizedOperandSegments trait.
//  - `rangeType` is the return type of getters that return a range of
//    operands. Individual operands are returned as `::mlir::Value`.
//  - `rangeBeginCall` is the code that yields an iterator to the beginning of
//    the operand range.
//  - `rangeSizeCall` is the code that yields the number of operands.
//  - `getOperandCallPattern` is intended to hold the code that fetches a
//    single operand, with its position substituted for the "{0}" marker. The
//    pattern should work for any kind of op, including one-operand ops that
//    lack `getOperand(unsigned)`. NOTE(review): this parameter is not
//    referenced by the body below.
//
// Aborts with a fatal error when the op's variadic operands are not sized
// consistently with its AttrSizedOperandSegments / SameVariadicOperandSize
// traits.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
                                        bool isAdaptor, StringRef sizeAttrInit,
                                        StringRef rangeType,
                                        StringRef rangeBeginCall,
                                        StringRef rangeSizeCall,
                                        StringRef getOperandCallPattern) {
  const int numOperands = op.getNumOperands();
  const int numVariadicOperands = op.getNumVariableLengthOperands();
  const int numNormalOperands = numOperands - numVariadicOperands;

  const auto *sameVariadicSize =
      op.getTrait("::mlir::OpTrait::SameVariadicOperandSize");
  const auto *attrSizedOperands =
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");

  // Several variadic operands need a trait describing how their sizes relate.
  if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
                                 "specification over their sizes");
  }

  if (numVariadicOperands < 2 && attrSizedOperands) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic operands "
                                 "to use 'AttrSizedOperandSegments' trait");
  }

  if (attrSizedOperands && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedOperandSegments' and "
                    "'SameVariadicOperandSize' traits");
  }

  // First emit a few "sink" getter methods upon which we layer all nicer named
  // getter methods.
  generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength",
                                numVariadicOperands, numNormalOperands,
                                rangeSizeCall, attrSizedOperands, sizeAttrInit,
                                const_cast<Operator &>(op).getOperands());

  auto *m = opClass.addMethod(rangeType, "getODSOperands",
                              MethodParameter("unsigned", "index"));
  ERROR_IF_PRUNED(m, "getODSOperands", op);
  auto &body = m->body();
  body << formatv(valueRangeReturnCode, rangeBeginCall,
                  "getODSOperandIndexAndLength(index)");

  // Then we emit nicer named getter methods by redirecting to the "sink" getter
  // method.
  for (int i = 0; i != numOperands; ++i) {
    const auto &operand = op.getOperand(i);
    if (operand.name.empty())
      continue;
    for (StringRef name : op.getGetterNames(operand.name)) {
      if (operand.isOptional()) {
        // Optional operand: returns a null Value when it is absent.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << " auto operands = getODSOperands(" << i << ");\n"
                  << " return operands.empty() ? ::mlir::Value() : "
                     "*operands.begin();";
      } else if (operand.isVariadicOfVariadic()) {
        // Nested variadic: split by the segment-size attribute named in the
        // operand's constraint.
        std::string segmentAttr = op.getGetterName(
            operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
        if (isAdaptor) {
          m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>",
                                name);
          ERROR_IF_PRUNED(m, name, op);
          m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
                                     segmentAttr, i);
          continue;
        }

        m = opClass.addMethod("::mlir::OperandRangeRange", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << " return getODSOperands(" << i << ").split("
                  << segmentAttr << "Attr());";
      } else if (operand.isVariadic()) {
        m = opClass.addMethod(rangeType, name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << " return getODSOperands(" << i << ");";
      } else {
        // Single operand: the first (and only) value of its range.
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << " return *getODSOperands(" << i << ").begin();";
      }
    }
  }
}
1257
genNamedOperandGetters()1258 void OpEmitter::genNamedOperandGetters() {
1259 // Build the code snippet used for initializing the operand_segment_size)s
1260 // array.
1261 std::string attrSizeInitCode;
1262 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
1263 attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
1264 emitHelper.getAttr(operandSegmentAttrName));
1265 }
1266
1267 generateNamedOperandGetters(
1268 op, opClass,
1269 /*isAdaptor=*/false,
1270 /*sizeAttrInit=*/attrSizeInitCode,
1271 /*rangeType=*/"::mlir::Operation::operand_range",
1272 /*rangeBeginCall=*/"getOperation()->operand_begin()",
1273 /*rangeSizeCall=*/"getOperation()->getNumOperands()",
1274 /*getOperandCallPattern=*/"getOperation()->getOperand({0})");
1275 }
1276
genNamedOperandSetters()1277 void OpEmitter::genNamedOperandSetters() {
1278 auto *attrSizedOperands =
1279 op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
1280 for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
1281 const auto &operand = op.getOperand(i);
1282 if (operand.name.empty())
1283 continue;
1284 for (StringRef name : op.getGetterNames(operand.name)) {
1285 auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
1286 ? "::mlir::MutableOperandRangeRange"
1287 : "::mlir::MutableOperandRange",
1288 (name + "Mutable").str());
1289 ERROR_IF_PRUNED(m, name, op);
1290 auto &body = m->body();
1291 body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
1292 << " auto mutableRange = "
1293 "::mlir::MutableOperandRange(getOperation(), "
1294 "range.first, range.second";
1295 if (attrSizedOperands) {
1296 body << formatv(
1297 ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
1298 emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
1299 }
1300 body << ");\n";
1301
1302 // If this operand is a nested variadic, we split the range into a
1303 // MutableOperandRangeRange that provides a range over all of the
1304 // sub-ranges.
1305 if (operand.isVariadicOfVariadic()) {
1306 body << " return "
1307 "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
1308 << op.getGetterName(
1309 operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
1310 << "AttrName()));\n";
1311 } else {
1312 // Otherwise, we use the full range directly.
1313 body << " return mutableRange;\n";
1314 }
1315 }
1316 }
1317 }
1318
/// Generates the result accessors of the op class: the "sink" methods
/// `getODSResultIndexAndLength(index)` and `getODSResults(index)`, and, for
/// every getter name of every named result, a getter returning a
/// `::mlir::Value` (null when an optional result is absent) or a
/// `::mlir::Operation::result_range` for a variadic result.
///
/// Aborts with a fatal error when the op's variadic results are not sized
/// consistently with its AttrSizedResultSegments / SameVariadicResultSize
/// traits.
void OpEmitter::genNamedResultGetters() {
  const int numResults = op.getNumResults();
  const int numVariadicResults = op.getNumVariableLengthResults();
  const int numNormalResults = numResults - numVariadicResults;

  // If we have more than one variadic results, we need more complicated logic
  // to calculate the value range for each result.

  const auto *sameVariadicSize =
      op.getTrait("::mlir::OpTrait::SameVariadicResultSize");
  const auto *attrSizedResults =
      op.getTrait("::mlir::OpTrait::AttrSizedResultSegments");

  if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) {
    PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
                                 "specification over their sizes");
  }

  if (numVariadicResults < 2 && attrSizedResults) {
    PrintFatalError(op.getLoc(), "op must have at least two variadic results "
                                 "to use 'AttrSizedResultSegments' trait");
  }

  if (attrSizedResults && sameVariadicSize) {
    PrintFatalError(op.getLoc(),
                    "op cannot have both 'AttrSizedResultSegments' and "
                    "'SameVariadicResultSize' traits");
  }

  // Build the initializer string for the result segment size attribute.
  std::string attrSizeInitCode;
  if (attrSizedResults) {
    attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
                               emitHelper.getAttr(resultSegmentAttrName));
  }

  generateValueRangeStartAndEnd(
      opClass, "getODSResultIndexAndLength", numVariadicResults,
      numNormalResults, "getOperation()->getNumResults()", attrSizedResults,
      attrSizeInitCode, op.getResults());

  auto *m =
      opClass.addMethod("::mlir::Operation::result_range", "getODSResults",
                        MethodParameter("unsigned", "index"));
  ERROR_IF_PRUNED(m, "getODSResults", op);
  m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()",
                       "getODSResultIndexAndLength(index)");

  // Named getters, each redirecting to getODSResults(i).
  for (int i = 0; i != numResults; ++i) {
    const auto &result = op.getResult(i);
    if (result.name.empty())
      continue;
    for (StringRef name : op.getGetterNames(result.name)) {
      if (result.isOptional()) {
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body()
            << " auto results = getODSResults(" << i << ");\n"
            << " return results.empty() ? ::mlir::Value() : *results.begin();";
      } else if (result.isVariadic()) {
        m = opClass.addMethod("::mlir::Operation::result_range", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << " return getODSResults(" << i << ");";
      } else {
        m = opClass.addMethod("::mlir::Value", name);
        ERROR_IF_PRUNED(m, name, op);
        m->body() << " return *getODSResults(" << i << ").begin();";
      }
    }
  }
}
1390
genNamedRegionGetters()1391 void OpEmitter::genNamedRegionGetters() {
1392 unsigned numRegions = op.getNumRegions();
1393 for (unsigned i = 0; i < numRegions; ++i) {
1394 const auto ®ion = op.getRegion(i);
1395 if (region.name.empty())
1396 continue;
1397
1398 for (StringRef name : op.getGetterNames(region.name)) {
1399 // Generate the accessors for a variadic region.
1400 if (region.isVariadic()) {
1401 auto *m =
1402 opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name);
1403 ERROR_IF_PRUNED(m, name, op);
1404 m->body() << formatv(" return (*this)->getRegions().drop_front({0});",
1405 i);
1406 continue;
1407 }
1408
1409 auto *m = opClass.addMethod("::mlir::Region &", name);
1410 ERROR_IF_PRUNED(m, name, op);
1411 m->body() << formatv(" return (*this)->getRegion({0});", i);
1412 }
1413 }
1414 }
1415
genNamedSuccessorGetters()1416 void OpEmitter::genNamedSuccessorGetters() {
1417 unsigned numSuccessors = op.getNumSuccessors();
1418 for (unsigned i = 0; i < numSuccessors; ++i) {
1419 const NamedSuccessor &successor = op.getSuccessor(i);
1420 if (successor.name.empty())
1421 continue;
1422
1423 for (StringRef name : op.getGetterNames(successor.name)) {
1424 // Generate the accessors for a variadic successor list.
1425 if (successor.isVariadic()) {
1426 auto *m = opClass.addMethod("::mlir::SuccessorRange", name);
1427 ERROR_IF_PRUNED(m, name, op);
1428 m->body() << formatv(
1429 " return {std::next((*this)->successor_begin(), {0}), "
1430 "(*this)->successor_end()};",
1431 i);
1432 continue;
1433 }
1434
1435 auto *m = opClass.addMethod("::mlir::Block *", name);
1436 ERROR_IF_PRUNED(m, name, op);
1437 m->body() << formatv(" return (*this)->getSuccessor({0});", i);
1438 }
1439 }
1440 }
1441
canGenerateUnwrappedBuilder(const Operator & op)1442 static bool canGenerateUnwrappedBuilder(const Operator &op) {
1443 // If this op does not have native attributes at all, return directly to avoid
1444 // redefining builders.
1445 if (op.getNumNativeAttributes() == 0)
1446 return false;
1447
1448 bool canGenerate = false;
1449 // We are generating builders that take raw values for attributes. We need to
1450 // make sure the native attributes have a meaningful "unwrapped" value type
1451 // different from the wrapped mlir::Attribute type to avoid redefining
1452 // builders. This checks for the op has at least one such native attribute.
1453 for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) {
1454 const NamedAttribute &namedAttr = op.getAttribute(i);
1455 if (canUseUnwrappedRawValue(namedAttr.attr)) {
1456 canGenerate = true;
1457 break;
1458 }
1459 }
1460 return canGenerate;
1461 }
1462
canInferType(const Operator & op)1463 static bool canInferType(const Operator &op) {
1464 return op.getTrait("::mlir::InferTypeOpInterface::Trait");
1465 }
1466
/// Generates `build` methods whose operands and attributes are separate
/// parameters. For each attribute parameter kind (wrapped attributes, plus raw
/// unwrapped values when canGenerateUnwrappedBuilder allows it) it tries to
/// emit, in this order:
///  - a builder taking one type parameter per result;
///  - a builder that infers result types via `inferReturnTypes` (only for ops
///    with InferTypeOpInterface and no regions);
///  - a builder taking all result types as one `resultTypes` parameter.
/// Any builder the class does not accept (e.g. a redundant one) is skipped.
void OpEmitter::genSeparateArgParamBuilder() {
  SmallVector<AttrParamKind, 2> attrBuilderType;
  attrBuilderType.push_back(AttrParamKind::WrappedAttr);
  if (canGenerateUnwrappedBuilder(op))
    attrBuilderType.push_back(AttrParamKind::UnwrappedValue);

  // Emit with separate builders with or without unwrapped attributes and/or
  // inferring result type.
  auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind,
                  bool inferType) {
    SmallVector<MethodParameter> paramList;
    SmallVector<std::string, 4> resultNames;
    llvm::StringSet<> inferredAttributes;
    buildParamList(paramList, inferredAttributes, resultNames, paramKind,
                   attrType);

    auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
    // If the builder is redundant, skip generating the method.
    if (!m)
      return;
    auto &body = m->body();
    genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                           /*isRawValueAttr=*/attrType ==
                                               AttrParamKind::UnwrappedValue);

    // Push all result types to the operation state

    if (inferType) {
      // Generate builder that infers type too.
      // TODO: Subsume this with general checking if type can be
      // inferred automatically.
      // TODO: Expand to handle regions.
      // The generated code aborts via report_fatal_error if inference fails.
      body << formatv(R"(
::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
{1}.location, {1}.operands,
{1}.attributes.getDictionary({1}.getContext()),
/*regions=*/{{}, inferredReturnTypes)))
{1}.addTypes(inferredReturnTypes);
else
::llvm::report_fatal_error("Failed to infer result type(s).");)",
                      opClass.getClassName(), builderOpState);
      return;
    }

    switch (paramKind) {
    case TypeParamKind::None:
      return;
    case TypeParamKind::Separate:
      // One type parameter per result; optional results are added only when
      // their type parameter is set.
      for (int i = 0, e = op.getNumResults(); i < e; ++i) {
        if (op.getResult(i).isOptional())
          body << " if (" << resultNames[i] << ")\n ";
        body << " " << builderOpState << ".addTypes(" << resultNames[i]
             << ");\n";
      }
      return;
    case TypeParamKind::Collective: {
      int numResults = op.getNumResults();
      int numVariadicResults = op.getNumVariableLengthResults();
      int numNonVariadicResults = numResults - numVariadicResults;
      bool hasVariadicResult = numVariadicResults != 0;

      // Avoid emitting "resultTypes.size() >= 0u" which is always true.
      if (!(hasVariadicResult && numNonVariadicResults == 0))
        body << " "
             << "assert(resultTypes.size() "
             << (hasVariadicResult ? ">=" : "==") << " "
             << numNonVariadicResults
             << "u && \"mismatched number of results\");\n";
      body << " " << builderOpState << ".addTypes(resultTypes);\n";
    }
      return;
    }
    llvm_unreachable("unhandled TypeParamKind");
  };

  // Some of the build methods generated here may be ambiguous, but TableGen's
  // ambiguous function detection will elide those ones.
  for (auto attrType : attrBuilderType) {
    emit(attrType, TypeParamKind::Separate, /*inferType=*/false);
    if (canInferType(op) && op.getNumRegions() == 0)
      emit(attrType, TypeParamKind::None, /*inferType=*/true);
    emit(attrType, TypeParamKind::Collective, /*inferType=*/false);
  }
}
1552
genUseOperandAsResultTypeCollectiveParamBuilder()1553 void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
1554 int numResults = op.getNumResults();
1555
1556 // Signature
1557 SmallVector<MethodParameter> paramList;
1558 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1559 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1560 paramList.emplace_back("::mlir::ValueRange", "operands");
1561 // Provide default value for `attributes` when its the last parameter
1562 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1563 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1564 "attributes", attributesDefaultValue);
1565 if (op.getNumVariadicRegions())
1566 paramList.emplace_back("unsigned", "numRegions");
1567
1568 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1569 // If the builder is redundant, skip generating the method
1570 if (!m)
1571 return;
1572 auto &body = m->body();
1573
1574 // Operands
1575 body << " " << builderOpState << ".addOperands(operands);\n";
1576
1577 // Attributes
1578 body << " " << builderOpState << ".addAttributes(attributes);\n";
1579
1580 // Create the correct number of regions
1581 if (int numRegions = op.getNumRegions()) {
1582 body << llvm::formatv(
1583 " for (unsigned i = 0; i != {0}; ++i)\n",
1584 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1585 body << " (void)" << builderOpState << ".addRegion();\n";
1586 }
1587
1588 // Result types
1589 SmallVector<std::string, 2> resultTypes(numResults, "operands[0].getType()");
1590 body << " " << builderOpState << ".addTypes({"
1591 << llvm::join(resultTypes, ", ") << "});\n\n";
1592 }
1593
genPopulateDefaultAttributes()1594 void OpEmitter::genPopulateDefaultAttributes() {
1595 // All done if no attributes have default values.
1596 if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
1597 return !named.attr.hasDefaultValue();
1598 }))
1599 return;
1600
1601 SmallVector<MethodParameter> paramList;
1602 paramList.emplace_back("const ::mlir::RegisteredOperationName &", "opName");
1603 paramList.emplace_back("::mlir::NamedAttrList &", "attributes");
1604 auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList);
1605 ERROR_IF_PRUNED(m, "populateDefaultAttrs", op);
1606 auto &body = m->body();
1607 body.indent();
1608
1609 // Set default attributes that are unset.
1610 body << "auto attrNames = opName.getAttributeNames();\n";
1611 body << "::mlir::Builder " << odsBuilder
1612 << "(attrNames.front().getContext());\n";
1613 StringMap<int> attrIndex;
1614 for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) {
1615 attrIndex[it.value().first] = it.index();
1616 }
1617 for (const NamedAttribute &namedAttr : op.getAttributes()) {
1618 auto &attr = namedAttr.attr;
1619 if (!attr.hasDefaultValue())
1620 continue;
1621 auto index = attrIndex[namedAttr.name];
1622 body << "if (!attributes.get(attrNames[" << index << "])) {\n";
1623 FmtContext fctx;
1624 fctx.withBuilder(odsBuilder);
1625 std::string defaultValue = std::string(
1626 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1627 body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n",
1628 index, defaultValue);
1629 body.unindent() << "}\n";
1630 }
1631 }
1632
genInferredTypeCollectiveParamBuilder()1633 void OpEmitter::genInferredTypeCollectiveParamBuilder() {
1634 SmallVector<MethodParameter> paramList;
1635 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1636 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1637 paramList.emplace_back("::mlir::ValueRange", "operands");
1638 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1639 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1640 "attributes", attributesDefaultValue);
1641 if (op.getNumVariadicRegions())
1642 paramList.emplace_back("unsigned", "numRegions");
1643
1644 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1645 // If the builder is redundant, skip generating the method
1646 if (!m)
1647 return;
1648 auto &body = m->body();
1649
1650 int numResults = op.getNumResults();
1651 int numVariadicResults = op.getNumVariableLengthResults();
1652 int numNonVariadicResults = numResults - numVariadicResults;
1653
1654 int numOperands = op.getNumOperands();
1655 int numVariadicOperands = op.getNumVariableLengthOperands();
1656 int numNonVariadicOperands = numOperands - numVariadicOperands;
1657
1658 // Operands
1659 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1660 body << " assert(operands.size()"
1661 << (numVariadicOperands != 0 ? " >= " : " == ")
1662 << numNonVariadicOperands
1663 << "u && \"mismatched number of parameters\");\n";
1664 body << " " << builderOpState << ".addOperands(operands);\n";
1665 body << " " << builderOpState << ".addAttributes(attributes);\n";
1666
1667 // Create the correct number of regions
1668 if (int numRegions = op.getNumRegions()) {
1669 body << llvm::formatv(
1670 " for (unsigned i = 0; i != {0}; ++i)\n",
1671 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1672 body << " (void)" << builderOpState << ".addRegion();\n";
1673 }
1674
1675 // Result types
1676 body << formatv(R"(
1677 ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
1678 if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(),
1679 {1}.location, operands,
1680 {1}.attributes.getDictionary({1}.getContext()),
1681 {1}.regions, inferredReturnTypes))) {{)",
1682 opClass.getClassName(), builderOpState);
1683 if (numVariadicResults == 0 || numNonVariadicResults != 0)
1684 body << "\n assert(inferredReturnTypes.size()"
1685 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1686 << "u && \"mismatched number of return types\");";
1687 body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);";
1688
1689 body << formatv(R"(
1690 } else {{
1691 ::llvm::report_fatal_error("Failed to infer result type(s).");
1692 })",
1693 opClass.getClassName(), builderOpState);
1694 }
1695
genUseOperandAsResultTypeSeparateParamBuilder()1696 void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
1697 auto emit = [&](AttrParamKind attrType) {
1698 SmallVector<MethodParameter> paramList;
1699 SmallVector<std::string, 4> resultNames;
1700 llvm::StringSet<> inferredAttributes;
1701 buildParamList(paramList, inferredAttributes, resultNames,
1702 TypeParamKind::None, attrType);
1703
1704 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1705 // If the builder is redundant, skip generating the method
1706 if (!m)
1707 return;
1708 auto &body = m->body();
1709 genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
1710 /*isRawValueAttr=*/attrType ==
1711 AttrParamKind::UnwrappedValue);
1712
1713 auto numResults = op.getNumResults();
1714 if (numResults == 0)
1715 return;
1716
1717 // Push all result types to the operation state
1718 const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
1719 std::string resultType =
1720 formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str();
1721 body << " " << builderOpState << ".addTypes({" << resultType;
1722 for (int i = 1; i != numResults; ++i)
1723 body << ", " << resultType;
1724 body << "});\n\n";
1725 };
1726
1727 emit(AttrParamKind::WrappedAttr);
1728 // Generate additional builder(s) if any attributes can be "unwrapped"
1729 if (canGenerateUnwrappedBuilder(op))
1730 emit(AttrParamKind::UnwrappedValue);
1731 }
1732
genUseAttrAsResultTypeBuilder()1733 void OpEmitter::genUseAttrAsResultTypeBuilder() {
1734 SmallVector<MethodParameter> paramList;
1735 paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder");
1736 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1737 paramList.emplace_back("::mlir::ValueRange", "operands");
1738 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1739 "attributes", "{}");
1740 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1741 // If the builder is redundant, skip generating the method
1742 if (!m)
1743 return;
1744
1745 auto &body = m->body();
1746
1747 // Push all result types to the operation state
1748 std::string resultType;
1749 const auto &namedAttr = op.getAttribute(0);
1750
1751 body << " auto attrName = " << op.getGetterName(namedAttr.name)
1752 << "AttrName(" << builderOpState
1753 << ".name);\n"
1754 " for (auto attr : attributes) {\n"
1755 " if (attr.getName() != attrName) continue;\n";
1756 if (namedAttr.attr.isTypeAttr()) {
1757 resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()";
1758 } else {
1759 resultType = "attr.getValue().getType()";
1760 }
1761
1762 // Operands
1763 body << " " << builderOpState << ".addOperands(operands);\n";
1764
1765 // Attributes
1766 body << " " << builderOpState << ".addAttributes(attributes);\n";
1767
1768 // Result types
1769 SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1770 body << " " << builderOpState << ".addTypes({"
1771 << llvm::join(resultTypes, ", ") << "});\n";
1772 body << " }\n";
1773 }
1774
1775 /// Returns a signature of the builder. Updates the context `fctx` to enable
1776 /// replacement of $_builder and $_state in the body.
1777 static SmallVector<MethodParameter>
getBuilderSignature(const Builder & builder)1778 getBuilderSignature(const Builder &builder) {
1779 ArrayRef<Builder::Parameter> params(builder.getParameters());
1780
1781 // Inject builder and state arguments.
1782 SmallVector<MethodParameter> arguments;
1783 arguments.reserve(params.size() + 2);
1784 arguments.emplace_back("::mlir::OpBuilder &", odsBuilder);
1785 arguments.emplace_back("::mlir::OperationState &", builderOpState);
1786
1787 for (unsigned i = 0, e = params.size(); i < e; ++i) {
1788 // If no name is provided, generate one.
1789 Optional<StringRef> paramName = params[i].getName();
1790 std::string name =
1791 paramName ? paramName->str() : "odsArg" + std::to_string(i);
1792
1793 StringRef defaultValue;
1794 if (Optional<StringRef> defaultParamValue = params[i].getDefaultValue())
1795 defaultValue = *defaultParamValue;
1796
1797 arguments.emplace_back(params[i].getCppType(), std::move(name),
1798 defaultValue);
1799 }
1800
1801 return arguments;
1802 }
1803
genBuilder()1804 void OpEmitter::genBuilder() {
1805 // Handle custom builders if provided.
1806 for (const Builder &builder : op.getBuilders()) {
1807 SmallVector<MethodParameter> arguments = getBuilderSignature(builder);
1808
1809 Optional<StringRef> body = builder.getBody();
1810 auto properties = body ? Method::Static : Method::StaticDeclaration;
1811 auto *method =
1812 opClass.addMethod("void", "build", properties, std::move(arguments));
1813 if (body)
1814 ERROR_IF_PRUNED(method, "build", op);
1815
1816 FmtContext fctx;
1817 fctx.withBuilder(odsBuilder);
1818 fctx.addSubst("_state", builderOpState);
1819 if (body)
1820 method->body() << tgfmt(*body, &fctx);
1821 }
1822
1823 // Generate default builders that requires all result type, operands, and
1824 // attributes as parameters.
1825 if (op.skipDefaultBuilders())
1826 return;
1827
1828 // We generate three classes of builders here:
1829 // 1. one having a stand-alone parameter for each operand / attribute, and
1830 genSeparateArgParamBuilder();
1831 // 2. one having an aggregated parameter for all result types / operands /
1832 // attributes, and
1833 genCollectiveParamBuilder();
1834 // 3. one having a stand-alone parameter for each operand and attribute,
1835 // use the first operand or attribute's type as all result types
1836 // to facilitate different call patterns.
1837 if (op.getNumVariableLengthResults() == 0) {
1838 if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
1839 genUseOperandAsResultTypeSeparateParamBuilder();
1840 genUseOperandAsResultTypeCollectiveParamBuilder();
1841 }
1842 if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
1843 genUseAttrAsResultTypeBuilder();
1844 }
1845 }
1846
genCollectiveParamBuilder()1847 void OpEmitter::genCollectiveParamBuilder() {
1848 int numResults = op.getNumResults();
1849 int numVariadicResults = op.getNumVariableLengthResults();
1850 int numNonVariadicResults = numResults - numVariadicResults;
1851
1852 int numOperands = op.getNumOperands();
1853 int numVariadicOperands = op.getNumVariableLengthOperands();
1854 int numNonVariadicOperands = numOperands - numVariadicOperands;
1855
1856 SmallVector<MethodParameter> paramList;
1857 paramList.emplace_back("::mlir::OpBuilder &", "");
1858 paramList.emplace_back("::mlir::OperationState &", builderOpState);
1859 paramList.emplace_back("::mlir::TypeRange", "resultTypes");
1860 paramList.emplace_back("::mlir::ValueRange", "operands");
1861 // Provide default value for `attributes` when its the last parameter
1862 StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}";
1863 paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
1864 "attributes", attributesDefaultValue);
1865 if (op.getNumVariadicRegions())
1866 paramList.emplace_back("unsigned", "numRegions");
1867
1868 auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
1869 // If the builder is redundant, skip generating the method
1870 if (!m)
1871 return;
1872 auto &body = m->body();
1873
1874 // Operands
1875 if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
1876 body << " assert(operands.size()"
1877 << (numVariadicOperands != 0 ? " >= " : " == ")
1878 << numNonVariadicOperands
1879 << "u && \"mismatched number of parameters\");\n";
1880 body << " " << builderOpState << ".addOperands(operands);\n";
1881
1882 // Attributes
1883 body << " " << builderOpState << ".addAttributes(attributes);\n";
1884
1885 // Create the correct number of regions
1886 if (int numRegions = op.getNumRegions()) {
1887 body << llvm::formatv(
1888 " for (unsigned i = 0; i != {0}; ++i)\n",
1889 (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions)));
1890 body << " (void)" << builderOpState << ".addRegion();\n";
1891 }
1892
1893 // Result types
1894 if (numVariadicResults == 0 || numNonVariadicResults != 0)
1895 body << " assert(resultTypes.size()"
1896 << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
1897 << "u && \"mismatched number of return types\");\n";
1898 body << " " << builderOpState << ".addTypes(resultTypes);\n";
1899
1900 // Generate builder that infers type too.
1901 // TODO: Expand to handle successors.
1902 if (canInferType(op) && op.getNumSuccessors() == 0)
1903 genInferredTypeCollectiveParamBuilder();
1904 }
1905
/// Builds the parameter list of a separate-parameter `build` method.
///
/// `paramList` receives, in order:
///  - the `::mlir::OpBuilder &` and `::mlir::OperationState &` parameters;
///  - result-type parameters as selected by `typeParamKind` (none, one per
///    result, or a single aggregate `resultTypes`);
///  - one parameter per operand and per non-inferred attribute, in ODS
///    argument order;
///  - one parameter per successor;
///  - one `unsigned <region>Count` parameter per variadic region.
///
/// `inferredAttributes` receives the names of attributes the builder computes
/// itself (here: the segment-size attributes of variadic-of-variadic
/// operands). Those attributes get no parameter.
///
/// `resultTypeNames` is cleared and then filled with the names of the
/// result-type parameters that were added.
///
/// With AttrParamKind::UnwrappedValue, attributes are taken as their raw
/// return type where possible. A trailing run of defaulted attributes may also
/// receive C++ default arguments (see the start-index computation below).
void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> &paramList,
                               llvm::StringSet<> &inferredAttributes,
                               SmallVectorImpl<std::string> &resultTypeNames,
                               TypeParamKind typeParamKind,
                               AttrParamKind attrParamKind) {
  resultTypeNames.clear();
  auto numResults = op.getNumResults();
  resultTypeNames.reserve(numResults);

  paramList.emplace_back("::mlir::OpBuilder &", odsBuilder);
  paramList.emplace_back("::mlir::OperationState &", builderOpState);

  switch (typeParamKind) {
  case TypeParamKind::None:
    break;
  case TypeParamKind::Separate: {
    // Add parameters for all return types. Unnamed results get the
    // synthesized name `resultType<i>`.
    for (int i = 0; i < numResults; ++i) {
      const auto &result = op.getResult(i);
      std::string resultName = std::string(result.name);
      if (resultName.empty())
        resultName = std::string(formatv("resultType{0}", i));

      StringRef type =
          result.isVariadic() ? "::mlir::TypeRange" : "::mlir::Type";

      paramList.emplace_back(type, resultName, result.isOptional());
      resultTypeNames.emplace_back(std::move(resultName));
    }
  } break;
  case TypeParamKind::Collective: {
    paramList.emplace_back("::mlir::TypeRange", "resultTypes");
    resultTypeNames.push_back("resultTypes");
  } break;
  }

  // Add parameters for all arguments (operands and attributes).
  // Index of the first argument that receives a C++ default value; equal to
  // getNumArgs() when no argument does.
  int defaultValuedAttrStartIndex = op.getNumArgs();
  // Successors and variadic regions go at the end of the parameter list, so no
  // default arguments are possible.
  bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions();
  if (attrParamKind == AttrParamKind::UnwrappedValue && !hasTrailingParams) {
    // Calculate the start index from which we can attach default values in the
    // builder declaration. Walk backwards from the last argument and stop at
    // the first one that cannot take a default (an operand, an attribute
    // without a default value, or one that cannot be passed unwrapped).
    for (int i = op.getNumArgs() - 1; i >= 0; --i) {
      auto *namedAttr = op.getArg(i).dyn_cast<tblgen::NamedAttribute *>();
      if (!namedAttr || !namedAttr->attr.hasDefaultValue())
        break;

      if (!canUseUnwrappedRawValue(namedAttr->attr))
        break;

      // Creating an APInt requires us to provide bitwidth, value, and
      // signedness, which is complicated compared to others. Similarly
      // for APFloat.
      // TODO: Adjust the 'returnType' field of such attributes
      // to support them.
      StringRef retType = namedAttr->attr.getReturnType();
      if (retType == "::llvm::APInt" || retType == "::llvm::APFloat")
        break;

      defaultValuedAttrStartIndex = i;
    }
  }

  /// Collect any inferred attributes.
  for (const NamedTypeConstraint &operand : op.getOperands()) {
    if (operand.isVariadicOfVariadic()) {
      inferredAttributes.insert(
          operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
    }
  }

  // `numOperands` counts operands seen so far, so operand parameters are named
  // by operand index rather than by ODS argument index.
  for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
    Argument arg = op.getArg(i);
    if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
      StringRef type;
      if (operand->isVariadicOfVariadic())
        type = "::llvm::ArrayRef<::mlir::ValueRange>";
      else if (operand->isVariadic())
        type = "::mlir::ValueRange";
      else
        type = "::mlir::Value";

      paramList.emplace_back(type, getArgumentName(op, numOperands++),
                             operand->isOptional());
      continue;
    }
    const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
    const Attribute &attr = namedAttr.attr;

    // Inferred attributes don't need to be added to the param list.
    if (inferredAttributes.contains(namedAttr.name))
      continue;

    StringRef type;
    switch (attrParamKind) {
    case AttrParamKind::WrappedAttr:
      type = attr.getStorageType();
      break;
    case AttrParamKind::UnwrappedValue:
      // Fall back to the storage type when no raw-value form is available.
      if (canUseUnwrappedRawValue(attr))
        type = attr.getReturnType();
      else
        type = attr.getStorageType();
      break;
    }

    // Attach default value if requested and possible.
    std::string defaultValue;
    if (attrParamKind == AttrParamKind::UnwrappedValue &&
        i >= defaultValuedAttrStartIndex) {
      defaultValue += attr.getDefaultValue();
    }
    paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue),
                           attr.isOptional());
  }

  /// Insert parameters for each successor.
  for (const NamedSuccessor &succ : op.getSuccessors()) {
    StringRef type =
        succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *";
    paramList.emplace_back(type, succ.name);
  }

  /// Insert parameters for variadic regions.
  for (const NamedRegion &region : op.getRegions())
    if (region.isVariadic())
      paramList.emplace_back("unsigned",
                             llvm::formatv("{0}Count", region.name).str());
}
2037
genCodeForAddingArgAndRegionForBuilder(MethodBody & body,llvm::StringSet<> & inferredAttributes,bool isRawValueAttr)2038 void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
2039 MethodBody &body, llvm::StringSet<> &inferredAttributes,
2040 bool isRawValueAttr) {
2041 // Push all operands to the result.
2042 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
2043 std::string argName = getArgumentName(op, i);
2044 const NamedTypeConstraint &operand = op.getOperand(i);
2045 if (operand.constraint.isVariadicOfVariadic()) {
2046 body << " for (::mlir::ValueRange range : " << argName << ")\n "
2047 << builderOpState << ".addOperands(range);\n";
2048
2049 // Add the segment attribute.
2050 body << " {\n"
2051 << " ::llvm::SmallVector<int32_t> rangeSegments;\n"
2052 << " for (::mlir::ValueRange range : " << argName << ")\n"
2053 << " rangeSegments.push_back(range.size());\n"
2054 << " " << builderOpState << ".addAttribute("
2055 << op.getGetterName(
2056 operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
2057 << "AttrName(" << builderOpState << ".name), " << odsBuilder
2058 << ".getI32TensorAttr(rangeSegments));"
2059 << " }\n";
2060 continue;
2061 }
2062
2063 if (operand.isOptional())
2064 body << " if (" << argName << ")\n ";
2065 body << " " << builderOpState << ".addOperands(" << argName << ");\n";
2066 }
2067
2068 // If the operation has the operand segment size attribute, add it here.
2069 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2070 std::string sizes = op.getGetterName(operandSegmentAttrName);
2071 body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName("
2072 << builderOpState << ".name), "
2073 << "odsBuilder.getI32VectorAttr({";
2074 interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
2075 const NamedTypeConstraint &operand = op.getOperand(i);
2076 if (!operand.isVariableLength()) {
2077 body << "1";
2078 return;
2079 }
2080
2081 std::string operandName = getArgumentName(op, i);
2082 if (operand.isOptional()) {
2083 body << "(" << operandName << " ? 1 : 0)";
2084 } else if (operand.isVariadicOfVariadic()) {
2085 body << llvm::formatv(
2086 "static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
2087 "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
2088 "range.size(); }))",
2089 operandName);
2090 } else {
2091 body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
2092 }
2093 });
2094 body << "}));\n";
2095 }
2096
2097 // Push all attributes to the result.
2098 for (const auto &namedAttr : op.getAttributes()) {
2099 auto &attr = namedAttr.attr;
2100 if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
2101 continue;
2102
2103 bool emitNotNullCheck =
2104 attr.isOptional() || (attr.hasDefaultValue() && !isRawValueAttr);
2105 if (emitNotNullCheck)
2106 body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
2107
2108 if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
2109 // If this is a raw value, then we need to wrap it in an Attribute
2110 // instance.
2111 FmtContext fctx;
2112 fctx.withBuilder("odsBuilder");
2113
2114 std::string builderTemplate = std::string(attr.getConstBuilderTemplate());
2115
2116 // For StringAttr, its constant builder call will wrap the input in
2117 // quotes, which is correct for normal string literals, but incorrect
2118 // here given we use function arguments. So we need to strip the
2119 // wrapping quotes.
2120 if (StringRef(builderTemplate).contains("\"$0\""))
2121 builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
2122
2123 std::string value =
2124 std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
2125 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
2126 builderOpState, op.getGetterName(namedAttr.name), value);
2127 } else {
2128 body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
2129 builderOpState, op.getGetterName(namedAttr.name),
2130 namedAttr.name);
2131 }
2132 if (emitNotNullCheck)
2133 body << " }\n";
2134 }
2135
2136 // Create the correct number of regions.
2137 for (const NamedRegion ®ion : op.getRegions()) {
2138 if (region.isVariadic())
2139 body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ",
2140 region.name);
2141
2142 body << " (void)" << builderOpState << ".addRegion();\n";
2143 }
2144
2145 // Push all successors to the result.
2146 for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) {
2147 body << formatv(" {0}.addSuccessors({1});\n", builderOpState,
2148 namedSuccessor.name);
2149 }
2150 }
2151
genCanonicalizerDecls()2152 void OpEmitter::genCanonicalizerDecls() {
2153 bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod");
2154 if (hasCanonicalizeMethod) {
2155 // static LogicResult FooOp::
2156 // canonicalize(FooOp op, PatternRewriter &rewriter);
2157 SmallVector<MethodParameter> paramList;
2158 paramList.emplace_back(op.getCppClassName(), "op");
2159 paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
2160 auto *m = opClass.declareStaticMethod("::mlir::LogicalResult",
2161 "canonicalize", std::move(paramList));
2162 ERROR_IF_PRUNED(m, "canonicalize", op);
2163 }
2164
2165 // We get a prototype for 'getCanonicalizationPatterns' if requested directly
2166 // or if using a 'canonicalize' method.
2167 bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer");
2168 if (!hasCanonicalizeMethod && !hasCanonicalizer)
2169 return;
2170
2171 // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize'
2172 // method, but not implementing 'getCanonicalizationPatterns' manually.
2173 bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer;
2174
2175 // Add a signature for getCanonicalizationPatterns if implemented by the
2176 // dialect or if synthesized to call 'canonicalize'.
2177 SmallVector<MethodParameter> paramList;
2178 paramList.emplace_back("::mlir::RewritePatternSet &", "results");
2179 paramList.emplace_back("::mlir::MLIRContext *", "context");
2180 auto kind = hasBody ? Method::Static : Method::StaticDeclaration;
2181 auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind,
2182 std::move(paramList));
2183
2184 // If synthesizing the method, fill it it.
2185 if (hasBody) {
2186 ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op);
2187 method->body() << " results.add(canonicalize);\n";
2188 }
2189 }
2190
genFolderDecls()2191 void OpEmitter::genFolderDecls() {
2192 bool hasSingleResult =
2193 op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0;
2194
2195 if (def.getValueAsBit("hasFolder")) {
2196 if (hasSingleResult) {
2197 auto *m = opClass.declareMethod(
2198 "::mlir::OpFoldResult", "fold",
2199 MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands"));
2200 ERROR_IF_PRUNED(m, "operands", op);
2201 } else {
2202 SmallVector<MethodParameter> paramList;
2203 paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands");
2204 paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &",
2205 "results");
2206 auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold",
2207 std::move(paramList));
2208 ERROR_IF_PRUNED(m, "fold", op);
2209 }
2210 }
2211 }
2212
genOpInterfaceMethods(const tblgen::InterfaceTrait * opTrait)2213 void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
2214 Interface interface = opTrait->getInterface();
2215
2216 // Get the set of methods that should always be declared.
2217 auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods();
2218 llvm::StringSet<> alwaysDeclaredMethods;
2219 alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(),
2220 alwaysDeclaredMethodsVec.end());
2221
2222 for (const InterfaceMethod &method : interface.getMethods()) {
2223 // Don't declare if the method has a body.
2224 if (method.getBody())
2225 continue;
2226 // Don't declare if the method has a default implementation and the op
2227 // didn't request that it always be declared.
2228 if (method.getDefaultImplementation() &&
2229 !alwaysDeclaredMethods.count(method.getName()))
2230 continue;
2231 // Interface methods are allowed to overlap with existing methods, so don't
2232 // check if pruned.
2233 (void)genOpInterfaceMethod(method);
2234 }
2235 }
2236
genOpInterfaceMethod(const InterfaceMethod & method,bool declaration)2237 Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
2238 bool declaration) {
2239 SmallVector<MethodParameter> paramList;
2240 for (const InterfaceMethod::Argument &arg : method.getArguments())
2241 paramList.emplace_back(arg.type, arg.name);
2242
2243 auto props = (method.isStatic() ? Method::Static : Method::None) |
2244 (declaration ? Method::Declaration : Method::None);
2245 return opClass.addMethod(method.getReturnType(), method.getName(), props,
2246 std::move(paramList));
2247 }
2248
genOpInterfaceMethods()2249 void OpEmitter::genOpInterfaceMethods() {
2250 for (const auto &trait : op.getTraits()) {
2251 if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
2252 if (opTrait->shouldDeclareMethods())
2253 genOpInterfaceMethods(opTrait);
2254 }
2255 }
2256
genSideEffectInterfaceMethods()2257 void OpEmitter::genSideEffectInterfaceMethods() {
2258 enum EffectKind { Operand, Result, Symbol, Static };
2259 struct EffectLocation {
2260 /// The effect applied.
2261 SideEffect effect;
2262
2263 /// The index if the kind is not static.
2264 unsigned index;
2265
2266 /// The kind of the location.
2267 unsigned kind;
2268 };
2269
2270 StringMap<SmallVector<EffectLocation, 1>> interfaceEffects;
2271 auto resolveDecorators = [&](Operator::var_decorator_range decorators,
2272 unsigned index, unsigned kind) {
2273 for (auto decorator : decorators)
2274 if (SideEffect *effect = dyn_cast<SideEffect>(&decorator)) {
2275 opClass.addTrait(effect->getInterfaceTrait());
2276 interfaceEffects[effect->getBaseEffectName()].push_back(
2277 EffectLocation{*effect, index, kind});
2278 }
2279 };
2280
2281 // Collect effects that were specified via:
2282 /// Traits.
2283 for (const auto &trait : op.getTraits()) {
2284 const auto *opTrait = dyn_cast<tblgen::SideEffectTrait>(&trait);
2285 if (!opTrait)
2286 continue;
2287 auto &effects = interfaceEffects[opTrait->getBaseEffectName()];
2288 for (auto decorator : opTrait->getEffects())
2289 effects.push_back(EffectLocation{cast<SideEffect>(decorator),
2290 /*index=*/0, EffectKind::Static});
2291 }
2292 /// Attributes and Operands.
2293 for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) {
2294 Argument arg = op.getArg(i);
2295 if (arg.is<NamedTypeConstraint *>()) {
2296 resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand);
2297 ++operandIt;
2298 continue;
2299 }
2300 const NamedAttribute *attr = arg.get<NamedAttribute *>();
2301 if (attr->attr.getBaseAttr().isSymbolRefAttr())
2302 resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol);
2303 }
2304 /// Results.
2305 for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
2306 resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result);
2307
2308 // The code used to add an effect instance.
2309 // {0}: The effect class.
2310 // {1}: Optional value or symbol reference.
2311 // {1}: The resource class.
2312 const char *addEffectCode =
2313 " effects.emplace_back({0}::get(), {1}{2}::get());\n";
2314
2315 for (auto &it : interfaceEffects) {
2316 // Generate the 'getEffects' method.
2317 std::string type = llvm::formatv("::llvm::SmallVectorImpl<::mlir::"
2318 "SideEffects::EffectInstance<{0}>> &",
2319 it.first())
2320 .str();
2321 auto *getEffects = opClass.addMethod("void", "getEffects",
2322 MethodParameter(type, "effects"));
2323 ERROR_IF_PRUNED(getEffects, "getEffects", op);
2324 auto &body = getEffects->body();
2325
2326 // Add effect instances for each of the locations marked on the operation.
2327 for (auto &location : it.second) {
2328 StringRef effect = location.effect.getName();
2329 StringRef resource = location.effect.getResource();
2330 if (location.kind == EffectKind::Static) {
2331 // A static instance has no attached value.
2332 body << llvm::formatv(addEffectCode, effect, "", resource).str();
2333 } else if (location.kind == EffectKind::Symbol) {
2334 // A symbol reference requires adding the proper attribute.
2335 const auto *attr = op.getArg(location.index).get<NamedAttribute *>();
2336 std::string argName = op.getGetterName(attr->name);
2337 if (attr->attr.isOptional()) {
2338 body << " if (auto symbolRef = " << argName << "Attr())\n "
2339 << llvm::formatv(addEffectCode, effect, "symbolRef, ", resource)
2340 .str();
2341 } else {
2342 body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
2343 resource)
2344 .str();
2345 }
2346 } else {
2347 // Otherwise this is an operand/result, so we need to attach the Value.
2348 body << " for (::mlir::Value value : getODS"
2349 << (location.kind == EffectKind::Operand ? "Operands" : "Results")
2350 << "(" << location.index << "))\n "
2351 << llvm::formatv(addEffectCode, effect, "value, ", resource).str();
2352 }
2353 }
2354 }
2355 }
2356
genTypeInterfaceMethods()2357 void OpEmitter::genTypeInterfaceMethods() {
2358 if (!op.allResultTypesKnown())
2359 return;
2360 // Generate 'inferReturnTypes' method declaration using the interface method
2361 // declared in 'InferTypeOpInterface' op interface.
2362 const auto *trait =
2363 cast<InterfaceTrait>(op.getTrait("::mlir::InferTypeOpInterface::Trait"));
2364 Interface interface = trait->getInterface();
2365 Method *method = [&]() -> Method * {
2366 for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
2367 if (interfaceMethod.getName() == "inferReturnTypes") {
2368 return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
2369 }
2370 }
2371 assert(0 && "unable to find inferReturnTypes interface method");
2372 return nullptr;
2373 }();
2374 ERROR_IF_PRUNED(method, "inferReturnTypes", op);
2375 auto &body = method->body();
2376 body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
2377
2378 FmtContext fctx;
2379 fctx.withBuilder("odsBuilder");
2380 body << " ::mlir::Builder odsBuilder(context);\n";
2381
2382 // Preprocess the result types and build all of the types used during
2383 // inferrence. This limits the amount of duplicated work when a type is used
2384 // to infer multiple others.
2385 llvm::DenseMap<Constraint, int> constraintsTypes;
2386 llvm::DenseMap<int, int> argumentsTypes;
2387 int inferredTypeIdx = 0;
2388 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2389 auto type = op.getSameTypeAsResult(i).front();
2390
2391 // If the type isn't an argument, it refers to a buildable type.
2392 if (!type.isArg()) {
2393 auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
2394 if (!it.second)
2395 continue;
2396
2397 // If we haven't seen this constraint, generate a variable for it.
2398 body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
2399 << tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
2400 continue;
2401 }
2402
2403 // Otherwise, this is an argument.
2404 int argIndex = type.getArg();
2405 auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
2406 if (!it.second)
2407 continue;
2408 body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
2409
2410 // If this is an operand, just index into operand list to access the type.
2411 auto arg = op.getArgToOperandOrAttribute(argIndex);
2412 if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
2413 body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
2414
2415 // If this is an attribute, index into the attribute dictionary.
2416 } else {
2417 auto *attr =
2418 op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
2419 body << "attributes.get(\"" << attr->name << "\").getType()";
2420 }
2421 body << ";\n";
2422 }
2423
2424 // Perform a second pass that handles assigning the inferred types to the
2425 // results.
2426 for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2427 auto types = op.getSameTypeAsResult(i);
2428
2429 // Append the inferred type.
2430 auto type = types.front();
2431 body << " inferredReturnTypes[" << i << "] = odsInferredType"
2432 << (type.isArg() ? argumentsTypes[type.getArg()]
2433 : constraintsTypes[type.getType()])
2434 << ";\n";
2435
2436 if (types.size() == 1)
2437 continue;
2438 // TODO: We could verify equality here, but skipping that for verification.
2439 }
2440 body << " return ::mlir::success();";
2441 }
2442
genParser()2443 void OpEmitter::genParser() {
2444 if (hasStringAttribute(def, "assemblyFormat"))
2445 return;
2446
2447 if (!def.getValueAsBit("hasCustomAssemblyFormat"))
2448 return;
2449
2450 SmallVector<MethodParameter> paramList;
2451 paramList.emplace_back("::mlir::OpAsmParser &", "parser");
2452 paramList.emplace_back("::mlir::OperationState &", "result");
2453
2454 auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse",
2455 std::move(paramList));
2456 ERROR_IF_PRUNED(method, "parse", op);
2457 }
2458
genPrinter()2459 void OpEmitter::genPrinter() {
2460 if (hasStringAttribute(def, "assemblyFormat"))
2461 return;
2462
2463 // Check to see if this op uses a c++ format.
2464 if (!def.getValueAsBit("hasCustomAssemblyFormat"))
2465 return;
2466 auto *method = opClass.declareMethod(
2467 "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p"));
2468 ERROR_IF_PRUNED(method, "print", op);
2469 }
2470
genVerifier()2471 void OpEmitter::genVerifier() {
2472 auto *implMethod =
2473 opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
2474 ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
2475 auto &implBody = implMethod->body();
2476
2477 populateSubstitutions(emitHelper, verifyCtx);
2478 genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
2479 genOperandResultVerifier(implBody, op.getOperands(), "operand");
2480 genOperandResultVerifier(implBody, op.getResults(), "result");
2481
2482 for (auto &trait : op.getTraits()) {
2483 if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
2484 implBody << tgfmt(" if (!($0))\n "
2485 "return emitOpError(\"failed to verify that $1\");\n",
2486 &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
2487 t->getSummary());
2488 }
2489 }
2490
2491 genRegionVerifier(implBody);
2492 genSuccessorVerifier(implBody);
2493
2494 implBody << " return ::mlir::success();\n";
2495
2496 // TODO: Some places use the `verifyInvariants` to do operation verification.
2497 // This may not act as their expectation because this doesn't call any
2498 // verifiers of native/interface traits. Needs to review those use cases and
2499 // see if we should use the mlir::verify() instead.
2500 auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
2501 ERROR_IF_PRUNED(method, "verifyInvariants", op);
2502 auto &body = method->body();
2503 if (def.getValueAsBit("hasVerifier")) {
2504 body << " if(::mlir::succeeded(verifyInvariantsImpl()) && "
2505 "::mlir::succeeded(verify()))\n";
2506 body << " return ::mlir::success();\n";
2507 body << " return ::mlir::failure();";
2508 } else {
2509 body << " return verifyInvariantsImpl();";
2510 }
2511 }
2512
genCustomVerifier()2513 void OpEmitter::genCustomVerifier() {
2514 if (def.getValueAsBit("hasVerifier")) {
2515 auto *method = opClass.declareMethod("::mlir::LogicalResult", "verify");
2516 ERROR_IF_PRUNED(method, "verify", op);
2517 }
2518
2519 if (def.getValueAsBit("hasRegionVerifier")) {
2520 auto *method =
2521 opClass.declareMethod("::mlir::LogicalResult", "verifyRegions");
2522 ERROR_IF_PRUNED(method, "verifyRegions", op);
2523 }
2524 }
2525
genOperandResultVerifier(MethodBody & body,Operator::const_value_range values,StringRef valueKind)2526 void OpEmitter::genOperandResultVerifier(MethodBody &body,
2527 Operator::const_value_range values,
2528 StringRef valueKind) {
2529 // Check that an optional value is at most 1 element.
2530 //
2531 // {0}: Value index.
2532 // {1}: "operand" or "result"
2533 const char *const verifyOptional = R"(
2534 if (valueGroup{0}.size() > 1) {
2535 return emitOpError("{1} group starting at #") << index
2536 << " requires 0 or 1 element, but found " << valueGroup{0}.size();
2537 }
2538 )";
2539 // Check the types of a range of values.
2540 //
2541 // {0}: Value index.
2542 // {1}: Type constraint function.
2543 // {2}: "operand" or "result"
2544 const char *const verifyValues = R"(
2545 for (auto v : valueGroup{0}) {
2546 if (::mlir::failed({1}(*this, v.getType(), "{2}", index++)))
2547 return ::mlir::failure();
2548 }
2549 )";
2550
2551 const auto canSkip = [](const NamedTypeConstraint &value) {
2552 return !value.hasPredicate() && !value.isOptional() &&
2553 !value.isVariadicOfVariadic();
2554 };
2555 if (values.empty() || llvm::all_of(values, canSkip))
2556 return;
2557
2558 FmtContext fctx;
2559
2560 body << " {\n unsigned index = 0; (void)index;\n";
2561
2562 for (const auto &staticValue : llvm::enumerate(values)) {
2563 const NamedTypeConstraint &value = staticValue.value();
2564
2565 bool hasPredicate = value.hasPredicate();
2566 bool isOptional = value.isOptional();
2567 bool isVariadicOfVariadic = value.isVariadicOfVariadic();
2568 if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
2569 continue;
2570 body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
2571 // Capitalize the first letter to match the function name
2572 valueKind.substr(0, 1).upper(), valueKind.substr(1),
2573 staticValue.index());
2574
2575 // If the constraint is optional check that the value group has at most 1
2576 // value.
2577 if (isOptional) {
2578 body << formatv(verifyOptional, staticValue.index(), valueKind);
2579 } else if (isVariadicOfVariadic) {
2580 body << formatv(
2581 " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
2582 "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
2583 " return ::mlir::failure();\n",
2584 value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
2585 staticValue.index());
2586 }
2587
2588 // Otherwise, if there is no predicate there is nothing left to do.
2589 if (!hasPredicate)
2590 continue;
2591 // Emit a loop to check all the dynamic values in the pack.
2592 StringRef constraintFn =
2593 staticVerifierEmitter.getTypeConstraintFn(value.constraint);
2594 body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind);
2595 }
2596
2597 body << " }\n";
2598 }
2599
genRegionVerifier(MethodBody & body)2600 void OpEmitter::genRegionVerifier(MethodBody &body) {
2601 /// Code to verify a region.
2602 ///
2603 /// {0}: Getter for the regions.
2604 /// {1}: The region constraint.
2605 /// {2}: The region's name.
2606 /// {3}: The region description.
2607 const char *const verifyRegion = R"(
2608 for (auto ®ion : {0})
2609 if (::mlir::failed({1}(*this, region, "{2}", index++)))
2610 return ::mlir::failure();
2611 )";
2612 /// Get a single region.
2613 ///
2614 /// {0}: The region's index.
2615 const char *const getSingleRegion =
2616 "::llvm::makeMutableArrayRef((*this)->getRegion({0}))";
2617
2618 // If we have no regions, there is nothing more to do.
2619 const auto canSkip = [](const NamedRegion ®ion) {
2620 return region.constraint.getPredicate().isNull();
2621 };
2622 auto regions = op.getRegions();
2623 if (regions.empty() && llvm::all_of(regions, canSkip))
2624 return;
2625
2626 body << " {\n unsigned index = 0; (void)index;\n";
2627 for (const auto &it : llvm::enumerate(regions)) {
2628 const auto ®ion = it.value();
2629 if (canSkip(region))
2630 continue;
2631
2632 auto getRegion = region.isVariadic()
2633 ? formatv("{0}()", op.getGetterName(region.name)).str()
2634 : formatv(getSingleRegion, it.index()).str();
2635 auto constraintFn =
2636 staticVerifierEmitter.getRegionConstraintFn(region.constraint);
2637 body << formatv(verifyRegion, getRegion, constraintFn, region.name);
2638 }
2639 body << " }\n";
2640 }
2641
genSuccessorVerifier(MethodBody & body)2642 void OpEmitter::genSuccessorVerifier(MethodBody &body) {
2643 const char *const verifySuccessor = R"(
2644 for (auto *successor : {0})
2645 if (::mlir::failed({1}(*this, successor, "{2}", index++)))
2646 return ::mlir::failure();
2647 )";
2648 /// Get a single successor.
2649 ///
2650 /// {0}: The successor's name.
2651 const char *const getSingleSuccessor = "::llvm::makeMutableArrayRef({0}())";
2652
2653 // If we have no successors, there is nothing more to do.
2654 const auto canSkip = [](const NamedSuccessor &successor) {
2655 return successor.constraint.getPredicate().isNull();
2656 };
2657 auto successors = op.getSuccessors();
2658 if (successors.empty() && llvm::all_of(successors, canSkip))
2659 return;
2660
2661 body << " {\n unsigned index = 0; (void)index;\n";
2662
2663 for (auto &it : llvm::enumerate(successors)) {
2664 const auto &successor = it.value();
2665 if (canSkip(successor))
2666 continue;
2667
2668 auto getSuccessor =
2669 formatv(successor.isVariadic() ? "{0}()" : getSingleSuccessor,
2670 successor.name, it.index())
2671 .str();
2672 auto constraintFn =
2673 staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint);
2674 body << formatv(verifySuccessor, getSuccessor, constraintFn,
2675 successor.name);
2676 }
2677 body << " }\n";
2678 }
2679
2680 /// Add a size count trait to the given operation class.
addSizeCountTrait(OpClass & opClass,StringRef traitKind,int numTotal,int numVariadic)2681 static void addSizeCountTrait(OpClass &opClass, StringRef traitKind,
2682 int numTotal, int numVariadic) {
2683 if (numVariadic != 0) {
2684 if (numTotal == numVariadic)
2685 opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s");
2686 else
2687 opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" +
2688 Twine(numTotal - numVariadic) + ">::Impl");
2689 return;
2690 }
2691 switch (numTotal) {
2692 case 0:
2693 opClass.addTrait("::mlir::OpTrait::Zero" + traitKind + "s");
2694 break;
2695 case 1:
2696 opClass.addTrait("::mlir::OpTrait::One" + traitKind);
2697 break;
2698 default:
2699 opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) +
2700 ">::Impl");
2701 break;
2702 }
2703 }
2704
/// Attaches all C++ traits to the generated op class. The order of `addTrait`
/// calls below is significant (see the OpInvariants comment): size-count
/// traits first, then structural native traits, then OpInvariants, then the
/// remaining native and interface traits.
void OpEmitter::genTraits() {
  // Add region size trait.
  unsigned numRegions = op.getNumRegions();
  unsigned numVariadicRegions = op.getNumVariadicRegions();
  addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions);

  // Add result size traits.
  int numResults = op.getNumResults();
  int numVariadicResults = op.getNumVariableLengthResults();
  addSizeCountTrait(opClass, "Result", numResults, numVariadicResults);

  // For ops with exactly one non-variadic result, add a OneTypedResult trait
  // parameterized on the C++ class of that result's type constraint.
  if (numResults == 1 && numVariadicResults == 0) {
    auto cppName = op.getResults().begin()->constraint.getCPPClassName();
    opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
  }

  // Add successor size trait.
  unsigned numSuccessors = op.getNumSuccessors();
  unsigned numVariadicSuccessors = op.getNumVariadicSuccessors();
  addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors);

  // Operand counts (total and variable-length).
  int numOperands = op.getNumOperands();
  int numVariadicOperands = op.getNumVariableLengthOperands();

  // Add operand size trait.
  addSizeCountTrait(opClass, "Operand", numOperands, numVariadicOperands);

  // Native traits marked as structural are added before OpInvariants so that
  // they are verified earlier.
  for (const auto &trait : op.getTraits()) {
    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
      if (opTrait->isStructuralOpTrait())
        opClass.addTrait(opTrait->getFullyQualifiedTraitName());
    }
  }

  // OpInvariants wraps verifyInvariants, which needs to run before the
  // native/interface traits and after all the traits with `StructuralOpTrait`.
  opClass.addTrait("::mlir::OpTrait::OpInvariants");

  // Add the remaining (non-structural) native traits and all interface traits.
  for (const auto &trait : op.getTraits()) {
    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
      if (!opTrait->isStructuralOpTrait())
        opClass.addTrait(opTrait->getFullyQualifiedTraitName());
    } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) {
      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
    }
  }
}
2758
genOpNameGetter()2759 void OpEmitter::genOpNameGetter() {
2760 auto *method = opClass.addStaticMethod<Method::Constexpr>(
2761 "::llvm::StringLiteral", "getOperationName");
2762 ERROR_IF_PRUNED(method, "getOperationName", op);
2763 method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName()
2764 << "\");";
2765 }
2766
genOpAsmInterface()2767 void OpEmitter::genOpAsmInterface() {
2768 // If the user only has one results or specifically added the Asm trait,
2769 // then don't generate it for them. We specifically only handle multi result
2770 // operations, because the name of a single result in the common case is not
2771 // interesting(generally 'result'/'output'/etc.).
2772 // TODO: We could also add a flag to allow operations to opt in to this
2773 // generation, even if they only have a single operation.
2774 int numResults = op.getNumResults();
2775 if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait"))
2776 return;
2777
2778 SmallVector<StringRef, 4> resultNames(numResults);
2779 for (int i = 0; i != numResults; ++i)
2780 resultNames[i] = op.getResultName(i);
2781
2782 // Don't add the trait if none of the results have a valid name.
2783 if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); }))
2784 return;
2785 opClass.addTrait("::mlir::OpAsmOpInterface::Trait");
2786
2787 // Generate the right accessor for the number of results.
2788 auto *method = opClass.addMethod(
2789 "void", "getAsmResultNames",
2790 MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn"));
2791 ERROR_IF_PRUNED(method, "getAsmResultNames", op);
2792 auto &body = method->body();
2793 for (int i = 0; i != numResults; ++i) {
2794 body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n"
2795 << " if (!llvm::empty(resultGroup" << i << "))\n"
2796 << " setNameFn(*resultGroup" << i << ".begin(), \""
2797 << resultNames[i] << "\");\n";
2798 }
2799 }
2800
2801 //===----------------------------------------------------------------------===//
2802 // OpOperandAdaptor emitter
2803 //===----------------------------------------------------------------------===//
2804
namespace {
// Helper class to emit Op operand adaptors to an output stream. A generated
// adaptor stores an op's operands (as `::mlir::ValueRange`), its attribute
// dictionary, its regions and an optional operation name, and it provides
// named getters that mirror those defined on the Op. Adaptors can be built
// either from raw values/attributes/regions or from an existing op instance.
class OpOperandAdaptorEmitter {
public:
  // Builds the adaptor class for `op` and writes its declaration to `os`.
  static void
  emitDecl(const Operator &op,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter,
           raw_ostream &os);
  // Builds the adaptor class for `op` and writes its definitions to `os`.
  static void
  emitDef(const Operator &op,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter,
          raw_ostream &os);

private:
  // Populates `adaptor` with all fields, constructors and accessors for `op`.
  explicit OpOperandAdaptorEmitter(
      const Operator &op,
      const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The operation for which to emit an adaptor.
  const Operator &op;

  // The generated adaptor class.
  Class adaptor;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting adaptor code (constructed with emitForOp=false).
  OpOrAdaptorHelper emitHelper;
};
} // namespace
2842
OpOperandAdaptorEmitter(const Operator & op,const StaticVerifierFunctionEmitter & staticVerifierEmitter)2843 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
2844 const Operator &op,
2845 const StaticVerifierFunctionEmitter &staticVerifierEmitter)
2846 : op(op), adaptor(op.getAdaptorName()),
2847 staticVerifierEmitter(staticVerifierEmitter),
2848 emitHelper(op, /*emitForOp=*/false) {
2849 adaptor.addField("::mlir::ValueRange", "odsOperands");
2850 adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
2851 adaptor.addField("::mlir::RegionRange", "odsRegions");
2852 adaptor.addField("::llvm::Optional<::mlir::OperationName>", "odsOpName");
2853
2854 const auto *attrSizedOperands =
2855 op.getTrait("::m::OpTrait::AttrSizedOperandSegments");
2856 {
2857 SmallVector<MethodParameter> paramList;
2858 paramList.emplace_back("::mlir::ValueRange", "values");
2859 paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
2860 attrSizedOperands ? "" : "nullptr");
2861 paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
2862 auto *constructor = adaptor.addConstructor(std::move(paramList));
2863
2864 constructor->addMemberInitializer("odsOperands", "values");
2865 constructor->addMemberInitializer("odsAttrs", "attrs");
2866 constructor->addMemberInitializer("odsRegions", "regions");
2867
2868 MethodBody &body = constructor->body();
2869 body.indent() << "if (odsAttrs)\n";
2870 body.indent() << formatv(
2871 "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
2872 op.getOperationName());
2873 }
2874
2875 {
2876 auto *constructor =
2877 adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
2878 constructor->addMemberInitializer("odsOperands", "op->getOperands()");
2879 constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
2880 constructor->addMemberInitializer("odsRegions", "op->getRegions()");
2881 constructor->addMemberInitializer("odsOpName", "op->getName()");
2882 }
2883
2884 {
2885 auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands");
2886 ERROR_IF_PRUNED(m, "getOperands", op);
2887 m->body() << " return odsOperands;";
2888 }
2889 std::string sizeAttrInit;
2890 if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
2891 sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
2892 emitHelper.getAttr(operandSegmentAttrName));
2893 }
2894 generateNamedOperandGetters(op, adaptor,
2895 /*isAdaptor=*/true, sizeAttrInit,
2896 /*rangeType=*/"::mlir::ValueRange",
2897 /*rangeBeginCall=*/"odsOperands.begin()",
2898 /*rangeSizeCall=*/"odsOperands.size()",
2899 /*getOperandCallPattern=*/"odsOperands[{0}]");
2900
2901 FmtContext fctx;
2902 fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");
2903
2904 // Generate named accessor with Attribute return type.
2905 auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
2906 Attribute attr) {
2907 auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr");
2908 ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
2909 auto &body = method->body().indent();
2910 body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n"
2911 << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name),
2912 attr.hasDefaultValue() || attr.isOptional()
2913 ? "dyn_cast_or_null"
2914 : "cast",
2915 attr.getStorageType());
2916
2917 if (attr.hasDefaultValue()) {
2918 // Use the default value if attribute is not set.
2919 // TODO: this is inefficient, we are recreating the attribute for every
2920 // call. This should be set instead.
2921 std::string defaultValue = std::string(
2922 tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2923 body << " if (!attr)\n attr = " << defaultValue << ";\n";
2924 }
2925 body << "return attr;\n";
2926 };
2927
2928 {
2929 auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes");
2930 ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
2931 m->body() << " return odsAttrs;";
2932 }
2933 for (auto &namedAttr : op.getAttributes()) {
2934 const auto &name = namedAttr.name;
2935 const auto &attr = namedAttr.attr;
2936 if (attr.isDerivedAttr())
2937 continue;
2938 for (const auto &emitName : op.getGetterNames(name)) {
2939 emitAttrWithStorageType(name, emitName, attr);
2940 emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr);
2941 }
2942 }
2943
2944 unsigned numRegions = op.getNumRegions();
2945 if (numRegions > 0) {
2946 auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions");
2947 ERROR_IF_PRUNED(m, "Adaptor::getRegions", op);
2948 m->body() << " return odsRegions;";
2949 }
2950 for (unsigned i = 0; i < numRegions; ++i) {
2951 const auto ®ion = op.getRegion(i);
2952 if (region.name.empty())
2953 continue;
2954
2955 // Generate the accessors for a variadic region.
2956 for (StringRef name : op.getGetterNames(region.name)) {
2957 if (region.isVariadic()) {
2958 auto *m = adaptor.addMethod("::mlir::RegionRange", name);
2959 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
2960 m->body() << formatv(" return odsRegions.drop_front({0});", i);
2961 continue;
2962 }
2963
2964 auto *m = adaptor.addMethod("::mlir::Region &", name);
2965 ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
2966 m->body() << formatv(" return *odsRegions[{0}];", i);
2967 }
2968 }
2969
2970 // Add verification function.
2971 addVerification();
2972 adaptor.finalize();
2973 }
2974
addVerification()2975 void OpOperandAdaptorEmitter::addVerification() {
2976 auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify",
2977 MethodParameter("::mlir::Location", "loc"));
2978 ERROR_IF_PRUNED(method, "verify", op);
2979 auto &body = method->body();
2980
2981 FmtContext verifyCtx;
2982 populateSubstitutions(emitHelper, verifyCtx);
2983 genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
2984
2985 body << " return ::mlir::success();";
2986 }
2987
emitDecl(const Operator & op,const StaticVerifierFunctionEmitter & staticVerifierEmitter,raw_ostream & os)2988 void OpOperandAdaptorEmitter::emitDecl(
2989 const Operator &op,
2990 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
2991 raw_ostream &os) {
2992 OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os);
2993 }
2994
emitDef(const Operator & op,const StaticVerifierFunctionEmitter & staticVerifierEmitter,raw_ostream & os)2995 void OpOperandAdaptorEmitter::emitDef(
2996 const Operator &op,
2997 const StaticVerifierFunctionEmitter &staticVerifierEmitter,
2998 raw_ostream &os) {
2999 OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os);
3000 }
3001
3002 // Emits the opcode enum and op classes.
emitOpClasses(const RecordKeeper & recordKeeper,const std::vector<Record * > & defs,raw_ostream & os,bool emitDecl)3003 static void emitOpClasses(const RecordKeeper &recordKeeper,
3004 const std::vector<Record *> &defs, raw_ostream &os,
3005 bool emitDecl) {
3006 // First emit forward declaration for each class, this allows them to refer
3007 // to each others in traits for example.
3008 if (emitDecl) {
3009 os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n";
3010 os << "#undef GET_OP_FWD_DEFINES\n";
3011 for (auto *def : defs) {
3012 Operator op(*def);
3013 NamespaceEmitter emitter(os, op.getCppNamespace());
3014 os << "class " << op.getCppClassName() << ";\n";
3015 }
3016 os << "#endif\n\n";
3017 }
3018
3019 IfDefScope scope("GET_OP_CLASSES", os);
3020 if (defs.empty())
3021 return;
3022
3023 // Generate all of the locally instantiated methods first.
3024 StaticVerifierFunctionEmitter staticVerifierEmitter(os, recordKeeper);
3025 os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
3026 staticVerifierEmitter.emitOpConstraints(defs, emitDecl);
3027
3028 for (auto *def : defs) {
3029 Operator op(*def);
3030 if (emitDecl) {
3031 {
3032 NamespaceEmitter emitter(os, op.getCppNamespace());
3033 os << formatv(opCommentHeader, op.getQualCppClassName(),
3034 "declarations");
3035 OpOperandAdaptorEmitter::emitDecl(op, staticVerifierEmitter, os);
3036 OpEmitter::emitDecl(op, os, staticVerifierEmitter);
3037 }
3038 // Emit the TypeID explicit specialization to have a single definition.
3039 if (!op.getCppNamespace().empty())
3040 os << "MLIR_DECLARE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
3041 << "::" << op.getCppClassName() << ")\n\n";
3042 } else {
3043 {
3044 NamespaceEmitter emitter(os, op.getCppNamespace());
3045 os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
3046 OpOperandAdaptorEmitter::emitDef(op, staticVerifierEmitter, os);
3047 OpEmitter::emitDef(op, os, staticVerifierEmitter);
3048 }
3049 // Emit the TypeID explicit specialization to have a single definition.
3050 if (!op.getCppNamespace().empty())
3051 os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << op.getCppNamespace()
3052 << "::" << op.getCppClassName() << ")\n\n";
3053 }
3054 }
3055 }
3056
3057 // Emits a comma-separated list of the ops.
emitOpList(const std::vector<Record * > & defs,raw_ostream & os)3058 static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
3059 IfDefScope scope("GET_OP_LIST", os);
3060
3061 interleave(
3062 // TODO: We are constructing the Operator wrapper instance just for
3063 // getting it's qualified class name here. Reduce the overhead by having a
3064 // lightweight version of Operator class just for that purpose.
3065 defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); },
3066 [&os]() { os << ",\n"; });
3067 }
3068
emitOpDecls(const RecordKeeper & recordKeeper,raw_ostream & os)3069 static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
3070 emitSourceFileHeader("Op Declarations", os);
3071
3072 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
3073 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
3074
3075 return false;
3076 }
3077
emitOpDefs(const RecordKeeper & recordKeeper,raw_ostream & os)3078 static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
3079 emitSourceFileHeader("Op Definitions", os);
3080
3081 std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
3082 emitOpList(defs, os);
3083 emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
3084
3085 return false;
3086 }
3087
3088 static mlir::GenRegistration
3089 genOpDecls("gen-op-decls", "Generate op declarations",
__anon5e0e87942302(const RecordKeeper &records, raw_ostream &os) 3090 [](const RecordKeeper &records, raw_ostream &os) {
3091 return emitOpDecls(records, os);
3092 });
3093
3094 static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions",
3095 [](const RecordKeeper &records,
__anon5e0e87942402(const RecordKeeper &records, raw_ostream &os) 3096 raw_ostream &os) {
3097 return emitOpDefs(records, os);
3098 });
3099