1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinAttributes.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/ExtensibleDialect.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/OperationSupport.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/IR/TypeUtilities.h"
26 #include "mlir/IR/Verifier.h"
27 #include "mlir/Interfaces/InferIntRangeInterface.h"
28 #include "mlir/Reducer/ReductionPatternInterface.h"
29 #include "mlir/Transforms/FoldUtils.h"
30 #include "mlir/Transforms/InliningUtils.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/StringSwitch.h"
34 
35 // Include this before the using namespace lines below to
36 // test that we don't have namespace dependencies.
37 #include "TestOpsDialect.cpp.inc"
38 
39 using namespace mlir;
40 using namespace test;
41 
42 void test::registerTestDialect(DialectRegistry &registry) {
43   registry.insert<TestDialect>();
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // TestDialect Interfaces
48 //===----------------------------------------------------------------------===//
49 
50 namespace {
51 
52 /// Testing the correctness of some traits.
53 static_assert(
54     llvm::is_detected<OpTrait::has_implicit_terminator_t,
55                       SingleBlockImplicitTerminatorOp>::value,
56     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
57 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
58                   SingleBlockImplicitTerminatorOp>::value,
59               "hasSingleBlockImplicitTerminator does not match "
60               "SingleBlockImplicitTerminatorOp");
61 
62 // Test support for interacting with the AsmPrinter.
63 struct TestOpAsmInterface : public OpAsmDialectInterface {
64   using OpAsmDialectInterface::OpAsmDialectInterface;
65 
66   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
67     StringAttr strAttr = attr.dyn_cast<StringAttr>();
68     if (!strAttr)
69       return AliasResult::NoAlias;
70 
71     // Check the contents of the string attribute to see what the test alias
72     // should be named.
73     Optional<StringRef> aliasName =
74         StringSwitch<Optional<StringRef>>(strAttr.getValue())
75             .Case("alias_test:dot_in_name", StringRef("test.alias"))
76             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
77             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
78             .Case("alias_test:sanitize_conflict_a",
79                   StringRef("test_alias_conflict0"))
80             .Case("alias_test:sanitize_conflict_b",
81                   StringRef("test_alias_conflict0_"))
82             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
83             .Default(llvm::None);
84     if (!aliasName)
85       return AliasResult::NoAlias;
86 
87     os << *aliasName;
88     return AliasResult::FinalAlias;
89   }
90 
91   AliasResult getAlias(Type type, raw_ostream &os) const final {
92     if (auto tupleType = type.dyn_cast<TupleType>()) {
93       if (tupleType.size() > 0 &&
94           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
95             return elemType.isa<SimpleAType>();
96           })) {
97         os << "test_tuple";
98         return AliasResult::FinalAlias;
99       }
100     }
101     if (auto intType = type.dyn_cast<TestIntegerType>()) {
102       if (intType.getSignedness() ==
103               TestIntegerType::SignednessSemantics::Unsigned &&
104           intType.getWidth() == 8) {
105         os << "test_ui8";
106         return AliasResult::FinalAlias;
107       }
108     }
109     return AliasResult::NoAlias;
110   }
111 };
112 
113 struct TestDialectFoldInterface : public DialectFoldInterface {
114   using DialectFoldInterface::DialectFoldInterface;
115 
116   /// Registered hook to check if the given region, which is attached to an
117   /// operation that is *not* isolated from above, should be used when
118   /// materializing constants.
119   bool shouldMaterializeInto(Region *region) const final {
120     // If this is a one region operation, then insert into it.
121     return isa<OneRegionOp>(region->getParentOp());
122   }
123 };
124 
125 /// This class defines the interface for handling inlining with standard
126 /// operations.
127 struct TestInlinerInterface : public DialectInlinerInterface {
128   using DialectInlinerInterface::DialectInlinerInterface;
129 
130   //===--------------------------------------------------------------------===//
131   // Analysis Hooks
132   //===--------------------------------------------------------------------===//
133 
134   bool isLegalToInline(Operation *call, Operation *callable,
135                        bool wouldBeCloned) const final {
136     // Don't allow inlining calls that are marked `noinline`.
137     return !call->hasAttr("noinline");
138   }
139   bool isLegalToInline(Region *, Region *, bool,
140                        BlockAndValueMapping &) const final {
141     // Inlining into test dialect regions is legal.
142     return true;
143   }
144   bool isLegalToInline(Operation *, Region *, bool,
145                        BlockAndValueMapping &) const final {
146     return true;
147   }
148 
149   bool shouldAnalyzeRecursively(Operation *op) const final {
150     // Analyze recursively if this is not a functional region operation, it
151     // froms a separate functional scope.
152     return !isa<FunctionalRegionOp>(op);
153   }
154 
155   //===--------------------------------------------------------------------===//
156   // Transformation Hooks
157   //===--------------------------------------------------------------------===//
158 
159   /// Handle the given inlined terminator by replacing it with a new operation
160   /// as necessary.
161   void handleTerminator(Operation *op,
162                         ArrayRef<Value> valuesToRepl) const final {
163     // Only handle "test.return" here.
164     auto returnOp = dyn_cast<TestReturnOp>(op);
165     if (!returnOp)
166       return;
167 
168     // Replace the values directly with the return operands.
169     assert(returnOp.getNumOperands() == valuesToRepl.size());
170     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
171       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
172   }
173 
174   /// Attempt to materialize a conversion for a type mismatch between a call
175   /// from this dialect, and a callable region. This method should generate an
176   /// operation that takes 'input' as the only operand, and produces a single
177   /// result of 'resultType'. If a conversion can not be generated, nullptr
178   /// should be returned.
179   Operation *materializeCallConversion(OpBuilder &builder, Value input,
180                                        Type resultType,
181                                        Location conversionLoc) const final {
182     // Only allow conversion for i16/i32 types.
183     if (!(resultType.isSignlessInteger(16) ||
184           resultType.isSignlessInteger(32)) ||
185         !(input.getType().isSignlessInteger(16) ||
186           input.getType().isSignlessInteger(32)))
187       return nullptr;
188     return builder.create<TestCastOp>(conversionLoc, resultType, input);
189   }
190 
191   void processInlinedCallBlocks(
192       Operation *call,
193       iterator_range<Region::iterator> inlinedBlocks) const final {
194     if (!isa<ConversionCallOp>(call))
195       return;
196 
197     // Set attributed on all ops in the inlined blocks.
198     for (Block &block : inlinedBlocks) {
199       block.walk([&](Operation *op) {
200         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
201       });
202     }
203   }
204 };
205 
206 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
207 public:
208   TestReductionPatternInterface(Dialect *dialect)
209       : DialectReductionPatternInterface(dialect) {}
210 
211   void populateReductionPatterns(RewritePatternSet &patterns) const final {
212     populateTestReductionPatterns(patterns);
213   }
214 };
215 
216 } // namespace
217 
218 //===----------------------------------------------------------------------===//
219 // Dynamic operations
220 //===----------------------------------------------------------------------===//
221 
222 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
223   return DynamicOpDefinition::get(
224       "dynamic_generic", dialect, [](Operation *op) { return success(); },
225       [](Operation *op) { return success(); });
226 }
227 
228 std::unique_ptr<DynamicOpDefinition>
229 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
230   return DynamicOpDefinition::get(
231       "dynamic_one_operand_two_results", dialect,
232       [](Operation *op) {
233         if (op->getNumOperands() != 1) {
234           op->emitOpError()
235               << "expected 1 operand, but had " << op->getNumOperands();
236           return failure();
237         }
238         if (op->getNumResults() != 2) {
239           op->emitOpError()
240               << "expected 2 results, but had " << op->getNumResults();
241           return failure();
242         }
243         return success();
244       },
245       [](Operation *op) { return success(); });
246 }
247 
248 std::unique_ptr<DynamicOpDefinition>
249 getDynamicCustomParserPrinterOp(TestDialect *dialect) {
250   auto verifier = [](Operation *op) {
251     if (op->getNumOperands() == 0 && op->getNumResults() == 0)
252       return success();
253     op->emitError() << "operation should have no operands and no results";
254     return failure();
255   };
256   auto regionVerifier = [](Operation *op) { return success(); };
257 
258   auto parser = [](OpAsmParser &parser, OperationState &state) {
259     return parser.parseKeyword("custom_keyword");
260   };
261 
262   auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
263     printer << op->getName() << " custom_keyword";
264   };
265 
266   return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
267                                   verifier, regionVerifier, parser, printer);
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // TestDialect
272 //===----------------------------------------------------------------------===//
273 
274 static void testSideEffectOpGetEffect(
275     Operation *op,
276     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
277 
278 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
279 struct TestOpEffectInterfaceFallback
280     : public TestEffectOpInterface::FallbackModel<
281           TestOpEffectInterfaceFallback> {
282   static bool classof(Operation *op) {
283     bool isSupportedOp =
284         op->getName().getStringRef() == "test.unregistered_side_effect_op";
285     assert(isSupportedOp && "Unexpected dispatch");
286     return isSupportedOp;
287   }
288 
289   void
290   getEffects(Operation *op,
291              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
292                  &effects) const {
293     testSideEffectOpGetEffect(op, effects);
294   }
295 };
296 
297 void TestDialect::initialize() {
298   registerAttributes();
299   registerTypes();
300   addOperations<
301 #define GET_OP_LIST
302 #include "TestOps.cpp.inc"
303       >();
304   registerDynamicOp(getDynamicGenericOp(this));
305   registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
306   registerDynamicOp(getDynamicCustomParserPrinterOp(this));
307 
308   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
309                 TestInlinerInterface, TestReductionPatternInterface>();
310   allowUnknownOperations();
311 
312   // Instantiate our fallback op interface that we'll use on specific
313   // unregistered op.
314   fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
315 }
316 TestDialect::~TestDialect() {
317   delete static_cast<TestOpEffectInterfaceFallback *>(
318       fallbackEffectOpInterfaces);
319 }
320 
321 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
322                                             Type type, Location loc) {
323   return builder.create<TestOpConstant>(loc, type, value);
324 }
325 
326 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
327     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
328     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
329     ::mlir::RegionRange regions,
330     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
331   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
332   return ::mlir::success();
333 }
334 
335 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
336                                                OperationName opName) {
337   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
338       typeID == TypeID::get<TestEffectOpInterface>())
339     return fallbackEffectOpInterfaces;
340   return nullptr;
341 }
342 
343 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
344                                                     NamedAttribute namedAttr) {
345   if (namedAttr.getName() == "test.invalid_attr")
346     return op->emitError() << "invalid to use 'test.invalid_attr'";
347   return success();
348 }
349 
350 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
351                                                     unsigned regionIndex,
352                                                     unsigned argIndex,
353                                                     NamedAttribute namedAttr) {
354   if (namedAttr.getName() == "test.invalid_attr")
355     return op->emitError() << "invalid to use 'test.invalid_attr'";
356   return success();
357 }
358 
359 LogicalResult
360 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
361                                          unsigned resultIndex,
362                                          NamedAttribute namedAttr) {
363   if (namedAttr.getName() == "test.invalid_attr")
364     return op->emitError() << "invalid to use 'test.invalid_attr'";
365   return success();
366 }
367 
368 Optional<Dialect::ParseOpHook>
369 TestDialect::getParseOperationHook(StringRef opName) const {
370   if (opName == "test.dialect_custom_printer") {
371     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
372       return parser.parseKeyword("custom_format");
373     }};
374   }
375   if (opName == "test.dialect_custom_format_fallback") {
376     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
377       return parser.parseKeyword("custom_format_fallback");
378     }};
379   }
380   if (opName == "test.dialect_custom_printer.with.dot") {
381     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
382       return ParseResult::success();
383     }};
384   }
385   return None;
386 }
387 
388 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
389 TestDialect::getOperationPrinter(Operation *op) const {
390   StringRef opName = op->getName().getStringRef();
391   if (opName == "test.dialect_custom_printer") {
392     return [](Operation *op, OpAsmPrinter &printer) {
393       printer.getStream() << " custom_format";
394     };
395   }
396   if (opName == "test.dialect_custom_format_fallback") {
397     return [](Operation *op, OpAsmPrinter &printer) {
398       printer.getStream() << " custom_format_fallback";
399     };
400   }
401   return {};
402 }
403 
404 //===----------------------------------------------------------------------===//
405 // TestBranchOp
406 //===----------------------------------------------------------------------===//
407 
408 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
409   assert(index == 0 && "invalid successor index");
410   return SuccessorOperands(getTargetOperandsMutable());
411 }
412 
413 //===----------------------------------------------------------------------===//
414 // TestProducingBranchOp
415 //===----------------------------------------------------------------------===//
416 
417 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
418   assert(index <= 1 && "invalid successor index");
419   if (index == 1)
420     return SuccessorOperands(getFirstOperandsMutable());
421   return SuccessorOperands(getSecondOperandsMutable());
422 }
423 
424 //===----------------------------------------------------------------------===//
// TestInternalBranchOp
426 //===----------------------------------------------------------------------===//
427 
428 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
429   assert(index <= 1 && "invalid successor index");
430   if (index == 0)
431     return SuccessorOperands(0, getSuccessOperandsMutable());
432   return SuccessorOperands(1, getErrorOperandsMutable());
433 }
434 
435 //===----------------------------------------------------------------------===//
436 // TestDialectCanonicalizerOp
437 //===----------------------------------------------------------------------===//
438 
439 static LogicalResult
440 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
441                                PatternRewriter &rewriter) {
442   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
443       op, rewriter.getI32IntegerAttr(42));
444   return success();
445 }
446 
447 void TestDialect::getCanonicalizationPatterns(
448     RewritePatternSet &results) const {
449   results.add(&dialectCanonicalizationPattern);
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // TestCallOp
454 //===----------------------------------------------------------------------===//
455 
456 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
457   // Check that the callee attribute was specified.
458   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
459   if (!fnAttr)
460     return emitOpError("requires a 'callee' symbol reference attribute");
461   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
462     return emitOpError() << "'" << fnAttr.getValue()
463                          << "' does not reference a valid function";
464   return success();
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // TestFoldToCallOp
469 //===----------------------------------------------------------------------===//
470 
471 namespace {
472 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
473   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
474 
475   LogicalResult matchAndRewrite(FoldToCallOp op,
476                                 PatternRewriter &rewriter) const override {
477     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
478                                               op.getCalleeAttr(), ValueRange());
479     return success();
480   }
481 };
482 } // namespace
483 
484 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
485                                                MLIRContext *context) {
486   results.add<FoldToCallOpPattern>(context);
487 }
488 
489 //===----------------------------------------------------------------------===//
490 // Test Format* operations
491 //===----------------------------------------------------------------------===//
492 
493 //===----------------------------------------------------------------------===//
494 // Parsing
495 
496 static ParseResult parseCustomOptionalOperand(
497     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
498   if (succeeded(parser.parseOptionalLParen())) {
499     optOperand.emplace();
500     if (parser.parseOperand(*optOperand) || parser.parseRParen())
501       return failure();
502   }
503   return success();
504 }
505 
506 static ParseResult parseCustomDirectiveOperands(
507     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
508     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
509     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
510   if (parser.parseOperand(operand))
511     return failure();
512   if (succeeded(parser.parseOptionalComma())) {
513     optOperand.emplace();
514     if (parser.parseOperand(*optOperand))
515       return failure();
516   }
517   if (parser.parseArrow() || parser.parseLParen() ||
518       parser.parseOperandList(varOperands) || parser.parseRParen())
519     return failure();
520   return success();
521 }
522 static ParseResult
523 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
524                             Type &optOperandType,
525                             SmallVectorImpl<Type> &varOperandTypes) {
526   if (parser.parseColon())
527     return failure();
528 
529   if (parser.parseType(operandType))
530     return failure();
531   if (succeeded(parser.parseOptionalComma())) {
532     if (parser.parseType(optOperandType))
533       return failure();
534   }
535   if (parser.parseArrow() || parser.parseLParen() ||
536       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
537     return failure();
538   return success();
539 }
540 static ParseResult
541 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
542                                  Type optOperandType,
543                                  const SmallVectorImpl<Type> &varOperandTypes) {
544   if (parser.parseKeyword("type_refs_capture"))
545     return failure();
546 
547   Type operandType2, optOperandType2;
548   SmallVector<Type, 1> varOperandTypes2;
549   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
550                                   varOperandTypes2))
551     return failure();
552 
553   if (operandType != operandType2 || optOperandType != optOperandType2 ||
554       varOperandTypes != varOperandTypes2)
555     return failure();
556 
557   return success();
558 }
559 static ParseResult parseCustomDirectiveOperandsAndTypes(
560     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
561     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
562     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
563     Type &operandType, Type &optOperandType,
564     SmallVectorImpl<Type> &varOperandTypes) {
565   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
566       parseCustomDirectiveResults(parser, operandType, optOperandType,
567                                   varOperandTypes))
568     return failure();
569   return success();
570 }
571 static ParseResult parseCustomDirectiveRegions(
572     OpAsmParser &parser, Region &region,
573     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
574   if (parser.parseRegion(region))
575     return failure();
576   if (failed(parser.parseOptionalComma()))
577     return success();
578   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
579   if (parser.parseRegion(*varRegion))
580     return failure();
581   varRegions.emplace_back(std::move(varRegion));
582   return success();
583 }
584 static ParseResult
585 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
586                                SmallVectorImpl<Block *> &varSuccessors) {
587   if (parser.parseSuccessor(successor))
588     return failure();
589   if (failed(parser.parseOptionalComma()))
590     return success();
591   Block *varSuccessor;
592   if (parser.parseSuccessor(varSuccessor))
593     return failure();
594   varSuccessors.append(2, varSuccessor);
595   return success();
596 }
597 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
598                                                   IntegerAttr &attr,
599                                                   IntegerAttr &optAttr) {
600   if (parser.parseAttribute(attr))
601     return failure();
602   if (succeeded(parser.parseOptionalComma())) {
603     if (parser.parseAttribute(optAttr))
604       return failure();
605   }
606   return success();
607 }
608 
609 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
610                                                 NamedAttrList &attrs) {
611   return parser.parseOptionalAttrDict(attrs);
612 }
613 static ParseResult parseCustomDirectiveOptionalOperandRef(
614     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
615   int64_t operandCount = 0;
616   if (parser.parseInteger(operandCount))
617     return failure();
618   bool expectedOptionalOperand = operandCount == 0;
619   return success(expectedOptionalOperand != optOperand.has_value());
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // Printing
624 
625 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
626                                        Value optOperand) {
627   if (optOperand)
628     printer << "(" << optOperand << ") ";
629 }
630 
631 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
632                                          Value operand, Value optOperand,
633                                          OperandRange varOperands) {
634   printer << operand;
635   if (optOperand)
636     printer << ", " << optOperand;
637   printer << " -> (" << varOperands << ")";
638 }
639 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
640                                         Type operandType, Type optOperandType,
641                                         TypeRange varOperandTypes) {
642   printer << " : " << operandType;
643   if (optOperandType)
644     printer << ", " << optOperandType;
645   printer << " -> (" << varOperandTypes << ")";
646 }
647 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
648                                              Operation *op, Type operandType,
649                                              Type optOperandType,
650                                              TypeRange varOperandTypes) {
651   printer << " type_refs_capture ";
652   printCustomDirectiveResults(printer, op, operandType, optOperandType,
653                               varOperandTypes);
654 }
655 static void printCustomDirectiveOperandsAndTypes(
656     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
657     OperandRange varOperands, Type operandType, Type optOperandType,
658     TypeRange varOperandTypes) {
659   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
660   printCustomDirectiveResults(printer, op, operandType, optOperandType,
661                               varOperandTypes);
662 }
663 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
664                                         Region &region,
665                                         MutableArrayRef<Region> varRegions) {
666   printer.printRegion(region);
667   if (!varRegions.empty()) {
668     printer << ", ";
669     for (Region &region : varRegions)
670       printer.printRegion(region);
671   }
672 }
673 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
674                                            Block *successor,
675                                            SuccessorRange varSuccessors) {
676   printer << successor;
677   if (!varSuccessors.empty())
678     printer << ", " << varSuccessors.front();
679 }
680 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
681                                            Attribute attribute,
682                                            Attribute optAttribute) {
683   printer << attribute;
684   if (optAttribute)
685     printer << ", " << optAttribute;
686 }
687 
688 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
689                                          DictionaryAttr attrs) {
690   printer.printOptionalAttrDict(attrs.getValue());
691 }
692 
693 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
694                                                    Operation *op,
695                                                    Value optOperand) {
696   printer << (optOperand ? "1" : "0");
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // Test IsolatedRegionOp - parse passthrough region arguments.
701 //===----------------------------------------------------------------------===//
702 
703 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
704                                     OperationState &result) {
705   // Parse the input operand.
706   OpAsmParser::Argument argInfo;
707   argInfo.type = parser.getBuilder().getIndexType();
708   if (parser.parseOperand(argInfo.ssaName) ||
709       parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
710     return failure();
711 
712   // Parse the body region, and reuse the operand info as the argument info.
713   Region *body = result.addRegion();
714   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
715 }
716 
717 void IsolatedRegionOp::print(OpAsmPrinter &p) {
718   p << "test.isolated_region ";
719   p.printOperand(getOperand());
720   p.shadowRegionArgs(getRegion(), getOperand());
721   p << ' ';
722   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // Test SSACFGRegionOp
727 //===----------------------------------------------------------------------===//
728 
729 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
730   return RegionKind::SSACFG;
731 }
732 
733 //===----------------------------------------------------------------------===//
734 // Test GraphRegionOp
735 //===----------------------------------------------------------------------===//
736 
737 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
738   return RegionKind::Graph;
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // Test AffineScopeOp
743 //===----------------------------------------------------------------------===//
744 
745 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
746   // Parse the body region, and reuse the operand info as the argument info.
747   Region *body = result.addRegion();
748   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
749 }
750 
751 void AffineScopeOp::print(OpAsmPrinter &p) {
752   p << "test.affine_scope ";
753   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
754 }
755 
756 //===----------------------------------------------------------------------===//
757 // Test parser.
758 //===----------------------------------------------------------------------===//
759 
760 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
761                                          OperationState &result) {
762   if (parser.parseOptionalColon())
763     return success();
764   uint64_t numResults;
765   if (parser.parseInteger(numResults))
766     return failure();
767 
768   IndexType type = parser.getBuilder().getIndexType();
769   for (unsigned i = 0; i < numResults; ++i)
770     result.addTypes(type);
771   return success();
772 }
773 
774 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
775   if (unsigned numResults = getNumResults())
776     p << " : " << numResults;
777 }
778 
779 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
780                                          OperationState &result) {
781   StringRef keyword;
782   if (parser.parseKeyword(&keyword))
783     return failure();
784   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
785   return success();
786 }
787 
788 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
789 
790 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
792 
793 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
794                                     OperationState &result) {
795   if (parser.parseKeyword("wraps"))
796     return failure();
797 
798   // Parse the wrapped op in a region
799   Region &body = *result.addRegion();
800   body.push_back(new Block);
801   Block &block = body.back();
802   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
803   if (!wrappedOp)
804     return failure();
805 
806   // Create a return terminator in the inner region, pass as operand to the
807   // terminator the returned values from the wrapped operation.
808   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
809   OpBuilder builder(parser.getContext());
810   builder.setInsertionPointToEnd(&block);
811   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
812 
813   // Get the results type for the wrapping op from the terminator operands.
814   Operation &returnOp = body.back().back();
815   result.types.append(returnOp.operand_type_begin(),
816                       returnOp.operand_type_end());
817 
818   // Use the location of the wrapped op for the "test.wrapping_region" op.
819   result.location = wrappedOp->getLoc();
820 
821   return success();
822 }
823 
824 void WrappingRegionOp::print(OpAsmPrinter &p) {
825   p << " wraps ";
826   p.printGenericOp(&getRegion().front().front());
827 }
828 
829 //===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs:
831 //   parseGenericOperationAfterOpName
832 //   parseCustomOperationName
833 //===----------------------------------------------------------------------===//
834 
835 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
836                                          OperationState &result) {
837 
838   SMLoc loc = parser.getCurrentLocation();
839   Location currLocation = parser.getEncodedSourceLoc(loc);
840 
841   // Parse the operands.
842   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
843   if (parser.parseOperandList(operands))
844     return failure();
845 
846   // Check if we are parsing the pretty-printed version
847   //  test.pretty_printed_region start <inner-op> end : <functional-type>
848   // Else fallback to parsing the "non pretty-printed" version.
849   if (!succeeded(parser.parseOptionalKeyword("start")))
850     return parser.parseGenericOperationAfterOpName(
851         result, llvm::makeArrayRef(operands));
852 
853   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
854   if (failed(parseOpNameInfo))
855     return failure();
856 
857   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
858 
859   FunctionType opFntype;
860   Optional<Location> explicitLoc;
861   if (parser.parseKeyword("end") || parser.parseColon() ||
862       parser.parseType(opFntype) ||
863       parser.parseOptionalLocationSpecifier(explicitLoc))
864     return failure();
865 
866   // If location of the op is explicitly provided, then use it; Else use
867   // the parser's current location.
868   Location opLoc = explicitLoc.value_or(currLocation);
869 
870   // Derive the SSA-values for op's operands.
871   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
872                              result.operands))
873     return failure();
874 
875   // Add a region for op.
876   Region &region = *result.addRegion();
877 
878   // Create a basic-block inside op's region.
879   Block &block = region.emplaceBlock();
880 
881   // Create and insert an "inner-op" operation in the block.
882   // Just for testing purposes, we can assume that inner op is a binary op with
883   // result and operand types all same as the test-op's first operand.
884   Type innerOpType = opFntype.getInput(0);
885   Value lhs = block.addArgument(innerOpType, opLoc);
886   Value rhs = block.addArgument(innerOpType, opLoc);
887 
888   OpBuilder builder(parser.getBuilder().getContext());
889   builder.setInsertionPointToStart(&block);
890 
891   Operation *innerOp =
892       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
893 
894   // Insert a return statement in the block returning the inner-op's result.
895   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
896 
897   // Populate the op operation-state with result-type and location.
898   result.addTypes(opFntype.getResults());
899   result.location = innerOp->getLoc();
900 
901   return success();
902 }
903 
904 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
905   p << ' ';
906   p.printOperands(getOperands());
907 
908   Operation &innerOp = getRegion().front().front();
909   // Assuming that region has a single non-terminator inner-op, if the inner-op
910   // meets some criteria (which in this case is a simple one  based on the name
911   // of inner-op), then we can print the entire region in a succinct way.
912   // Here we assume that the prototype of "special.op" can be trivially derived
913   // while parsing it back.
914   if (innerOp.getName().getStringRef().equals("special.op")) {
915     p << " start special.op end";
916   } else {
917     p << " (";
918     p.printRegion(getRegion());
919     p << ")";
920   }
921 
922   p << " : ";
923   p.printFunctionalType(*this);
924 }
925 
926 //===----------------------------------------------------------------------===//
927 // Test PolyForOp - parse list of region arguments.
928 //===----------------------------------------------------------------------===//
929 
930 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
931   SmallVector<OpAsmParser::Argument, 4> ivsInfo;
932   // Parse list of region arguments without a delimiter.
933   if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
934     return failure();
935 
936   // Parse the body region.
937   Region *body = result.addRegion();
938   for (auto &iv : ivsInfo)
939     iv.type = parser.getBuilder().getIndexType();
940   return parser.parseRegion(*body, ivsInfo);
941 }
942 
943 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
944 
945 void PolyForOp::getAsmBlockArgumentNames(Region &region,
946                                          OpAsmSetValueNameFn setNameFn) {
947   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
948   if (!arrayAttr)
949     return;
950   auto args = getRegion().front().getArguments();
951   auto e = std::min(arrayAttr.size(), args.size());
952   for (unsigned i = 0; i < e; ++i) {
953     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
954       setNameFn(args[i], strAttr.getValue());
955   }
956 }
957 
958 //===----------------------------------------------------------------------===//
959 // Test removing op with inner ops.
960 //===----------------------------------------------------------------------===//
961 
962 namespace {
963 struct TestRemoveOpWithInnerOps
964     : public OpRewritePattern<TestOpWithRegionPattern> {
965   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
966 
967   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
968 
969   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
970                                 PatternRewriter &rewriter) const override {
971     rewriter.eraseOp(op);
972     return success();
973   }
974 };
975 } // namespace
976 
977 void TestOpWithRegionPattern::getCanonicalizationPatterns(
978     RewritePatternSet &results, MLIRContext *context) {
979   results.add<TestRemoveOpWithInnerOps>(context);
980 }
981 
982 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
983   return getOperand();
984 }
985 
986 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
987   return getValue();
988 }
989 
990 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
991     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
992   for (Value input : this->getOperands()) {
993     results.push_back(input);
994   }
995   return success();
996 }
997 
998 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
999   assert(operands.size() == 1);
1000   if (operands.front()) {
1001     (*this)->setAttr("attr", operands.front());
1002     return getResult();
1003   }
1004   return {};
1005 }
1006 
1007 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
1008   return getOperand();
1009 }
1010 
1011 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
1012     MLIRContext *, Optional<Location> location, ValueRange operands,
1013     DictionaryAttr attributes, RegionRange regions,
1014     SmallVectorImpl<Type> &inferredReturnTypes) {
1015   if (operands[0].getType() != operands[1].getType()) {
1016     return emitOptionalError(location, "operand type mismatch ",
1017                              operands[0].getType(), " vs ",
1018                              operands[1].getType());
1019   }
1020   inferredReturnTypes.assign({operands[0].getType()});
1021   return success();
1022 }
1023 
1024 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
1025     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
1026     DictionaryAttr attributes, RegionRange regions,
1027     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1028   // Create return type consisting of the last element of the first operand.
1029   auto operandType = operands.front().getType();
1030   auto sval = operandType.dyn_cast<ShapedType>();
1031   if (!sval) {
1032     return emitOptionalError(location, "only shaped type operands allowed");
1033   }
1034   int64_t dim =
1035       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
1036   auto type = IntegerType::get(context, 17);
1037   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
1038   return success();
1039 }
1040 
1041 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
1042     OpBuilder &builder, ValueRange operands,
1043     llvm::SmallVectorImpl<Value> &shapes) {
1044   shapes = SmallVector<Value, 1>{
1045       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1046   return success();
1047 }
1048 
1049 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
1050     OpBuilder &builder, ValueRange operands,
1051     llvm::SmallVectorImpl<Value> &shapes) {
1052   Location loc = getLoc();
1053   shapes.reserve(operands.size());
1054   for (Value operand : llvm::reverse(operands)) {
1055     auto rank = operand.getType().cast<RankedTensorType>().getRank();
1056     auto currShape = llvm::to_vector<4>(
1057         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1058           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1059         }));
1060     shapes.push_back(builder.create<tensor::FromElementsOp>(
1061         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1062         currShape));
1063   }
1064   return success();
1065 }
1066 
1067 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1068     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1069   Location loc = getLoc();
1070   shapes.reserve(getNumOperands());
1071   for (Value operand : llvm::reverse(getOperands())) {
1072     auto currShape = llvm::to_vector<4>(llvm::map_range(
1073         llvm::seq<int64_t>(
1074             0, operand.getType().cast<RankedTensorType>().getRank()),
1075         [&](int64_t dim) -> Value {
1076           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1077         }));
1078     shapes.emplace_back(std::move(currShape));
1079   }
1080   return success();
1081 }
1082 
1083 //===----------------------------------------------------------------------===//
1084 // Test SideEffect interfaces
1085 //===----------------------------------------------------------------------===//
1086 
namespace {
/// A test resource for side effects. SideEffectOp::getEffects attributes an
/// effect to this resource, instead of the default resource, when the effect
/// entry contains a `test_resource` key.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  // Defines the TypeID that identifies this resource kind.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)

  /// Human-readable name of this resource.
  StringRef getName() final { return "<Test>"; }
};
} // namespace
1095 
1096 static void testSideEffectOpGetEffect(
1097     Operation *op,
1098     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1099         &effects) {
1100   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1101   if (!effectsAttr)
1102     return;
1103 
1104   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1105 }
1106 
1107 void SideEffectOp::getEffects(
1108     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1109   // Check for an effects attribute on the op instance.
1110   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1111   if (!effectsAttr)
1112     return;
1113 
1114   // If there is one, it is an array of dictionary attributes that hold
1115   // information on the effects of this operation.
1116   for (Attribute element : effectsAttr) {
1117     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1118 
1119     // Get the specific memory effect.
1120     MemoryEffects::Effect *effect =
1121         StringSwitch<MemoryEffects::Effect *>(
1122             effectElement.get("effect").cast<StringAttr>().getValue())
1123             .Case("allocate", MemoryEffects::Allocate::get())
1124             .Case("free", MemoryEffects::Free::get())
1125             .Case("read", MemoryEffects::Read::get())
1126             .Case("write", MemoryEffects::Write::get());
1127 
1128     // Check for a non-default resource to use.
1129     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1130     if (effectElement.get("test_resource"))
1131       resource = TestResource::get();
1132 
1133     // Check for a result to affect.
1134     if (effectElement.get("on_result"))
1135       effects.emplace_back(effect, getResult(), resource);
1136     else if (Attribute ref = effectElement.get("on_reference"))
1137       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1138     else
1139       effects.emplace_back(effect, resource);
1140   }
1141 }
1142 
1143 void SideEffectOp::getEffects(
1144     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1145   testSideEffectOpGetEffect(getOperation(), effects);
1146 }
1147 
1148 //===----------------------------------------------------------------------===//
1149 // StringAttrPrettyNameOp
1150 //===----------------------------------------------------------------------===//
1151 
1152 // This op has fancy handling of its SSA result name.
1153 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1154                                           OperationState &result) {
1155   // Add the result types.
1156   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1157     result.addTypes(parser.getBuilder().getIntegerType(32));
1158 
1159   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1160     return failure();
1161 
1162   // If the attribute dictionary contains no 'names' attribute, infer it from
1163   // the SSA name (if specified).
1164   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1165     return attr.getName() == "names";
1166   });
1167 
1168   // If there was no name specified, check to see if there was a useful name
1169   // specified in the asm file.
1170   if (hadNames || parser.getNumResults() == 0)
1171     return success();
1172 
1173   SmallVector<StringRef, 4> names;
1174   auto *context = result.getContext();
1175 
1176   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1177     auto resultName = parser.getResultName(i);
1178     StringRef nameStr;
1179     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1180       nameStr = resultName.first;
1181 
1182     names.push_back(nameStr);
1183   }
1184 
1185   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1186   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1187   return success();
1188 }
1189 
1190 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1191   // Note that we only need to print the "name" attribute if the asmprinter
1192   // result name disagrees with it.  This can happen in strange cases, e.g.
1193   // when there are conflicts.
1194   bool namesDisagree = getNames().size() != getNumResults();
1195 
1196   SmallString<32> resultNameStr;
1197   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1198     resultNameStr.clear();
1199     llvm::raw_svector_ostream tmpStream(resultNameStr);
1200     p.printOperand(getResult(i), tmpStream);
1201 
1202     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1203     if (!expectedName ||
1204         tmpStream.str().drop_front() != expectedName.getValue()) {
1205       namesDisagree = true;
1206     }
1207   }
1208 
1209   if (namesDisagree)
1210     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1211   else
1212     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1213 }
1214 
1215 // We set the SSA name in the asm syntax to the contents of the name
1216 // attribute.
1217 void StringAttrPrettyNameOp::getAsmResultNames(
1218     function_ref<void(Value, StringRef)> setNameFn) {
1219 
1220   auto value = getNames();
1221   for (size_t i = 0, e = value.size(); i != e; ++i)
1222     if (auto str = value[i].dyn_cast<StringAttr>())
1223       if (!str.getValue().empty())
1224         setNameFn(getResult(i), str.getValue());
1225 }
1226 
1227 void CustomResultsNameOp::getAsmResultNames(
1228     function_ref<void(Value, StringRef)> setNameFn) {
1229   ArrayAttr value = getNames();
1230   for (size_t i = 0, e = value.size(); i != e; ++i)
1231     if (auto str = value[i].dyn_cast<StringAttr>())
1232       if (!str.getValue().empty())
1233         setNameFn(getResult(i), str.getValue());
1234 }
1235 
1236 //===----------------------------------------------------------------------===//
1237 // ResultTypeWithTraitOp
1238 //===----------------------------------------------------------------------===//
1239 
1240 LogicalResult ResultTypeWithTraitOp::verify() {
1241   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1242     return success();
1243   return emitError("result type should have trait 'TestTypeTrait'");
1244 }
1245 
1246 //===----------------------------------------------------------------------===//
1247 // AttrWithTraitOp
1248 //===----------------------------------------------------------------------===//
1249 
1250 LogicalResult AttrWithTraitOp::verify() {
1251   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1252     return success();
1253   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1254 }
1255 
1256 //===----------------------------------------------------------------------===//
1257 // RegionIfOp
1258 //===----------------------------------------------------------------------===//
1259 
1260 void RegionIfOp::print(OpAsmPrinter &p) {
1261   p << " ";
1262   p.printOperands(getOperands());
1263   p << ": " << getOperandTypes();
1264   p.printArrowTypeList(getResultTypes());
1265   p << " then ";
1266   p.printRegion(getThenRegion(),
1267                 /*printEntryBlockArgs=*/true,
1268                 /*printBlockTerminators=*/true);
1269   p << " else ";
1270   p.printRegion(getElseRegion(),
1271                 /*printEntryBlockArgs=*/true,
1272                 /*printBlockTerminators=*/true);
1273   p << " join ";
1274   p.printRegion(getJoinRegion(),
1275                 /*printEntryBlockArgs=*/true,
1276                 /*printBlockTerminators=*/true);
1277 }
1278 
1279 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1280   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1281   SmallVector<Type, 2> operandTypes;
1282 
1283   result.regions.reserve(3);
1284   Region *thenRegion = result.addRegion();
1285   Region *elseRegion = result.addRegion();
1286   Region *joinRegion = result.addRegion();
1287 
1288   // Parse operand, type and arrow type lists.
1289   if (parser.parseOperandList(operandInfos) ||
1290       parser.parseColonTypeList(operandTypes) ||
1291       parser.parseArrowTypeList(result.types))
1292     return failure();
1293 
1294   // Parse all attached regions.
1295   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1296       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1297       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1298     return failure();
1299 
1300   return parser.resolveOperands(operandInfos, operandTypes,
1301                                 parser.getCurrentLocation(), result.operands);
1302 }
1303 
1304 OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) {
1305   assert(index && *index < 2 && "invalid region index");
1306   return getOperands();
1307 }
1308 
1309 void RegionIfOp::getSuccessorRegions(
1310     Optional<unsigned> index, ArrayRef<Attribute> operands,
1311     SmallVectorImpl<RegionSuccessor> &regions) {
1312   // We always branch to the join region.
1313   if (index) {
1314     if (index.value() < 2)
1315       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1316     else
1317       regions.push_back(RegionSuccessor(getResults()));
1318     return;
1319   }
1320 
1321   // The then and else regions are the entry regions of this op.
1322   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1323   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1324 }
1325 
1326 void RegionIfOp::getRegionInvocationBounds(
1327     ArrayRef<Attribute> operands,
1328     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1329   // Each region is invoked at most once.
1330   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1331 }
1332 
1333 //===----------------------------------------------------------------------===//
1334 // AnyCondOp
1335 //===----------------------------------------------------------------------===//
1336 
1337 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1338                                     ArrayRef<Attribute> operands,
1339                                     SmallVectorImpl<RegionSuccessor> &regions) {
1340   // The parent op branches into the only region, and the region branches back
1341   // to the parent op.
1342   if (!index)
1343     regions.emplace_back(&getRegion());
1344   else
1345     regions.emplace_back(getResults());
1346 }
1347 
1348 void AnyCondOp::getRegionInvocationBounds(
1349     ArrayRef<Attribute> operands,
1350     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1351   invocationBounds.emplace_back(1, 1);
1352 }
1353 
1354 //===----------------------------------------------------------------------===//
1355 // SingleNoTerminatorCustomAsmOp
1356 //===----------------------------------------------------------------------===//
1357 
1358 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1359                                                  OperationState &state) {
1360   Region *body = state.addRegion();
1361   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1362     return failure();
1363   return success();
1364 }
1365 
1366 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1367   printer.printRegion(
1368       getRegion(), /*printEntryBlockArgs=*/false,
1369       // This op has a single block without terminators. But explicitly mark
1370       // as not printing block terminators for testing.
1371       /*printBlockTerminators=*/false);
1372 }
1373 
1374 //===----------------------------------------------------------------------===//
1375 // TestVerifiersOp
1376 //===----------------------------------------------------------------------===//
1377 
1378 LogicalResult TestVerifiersOp::verify() {
1379   if (!getRegion().hasOneBlock())
1380     return emitOpError("`hasOneBlock` trait hasn't been verified");
1381 
1382   Operation *definingOp = getInput().getDefiningOp();
1383   if (definingOp && failed(mlir::verify(definingOp)))
1384     return emitOpError("operand hasn't been verified");
1385 
1386   emitRemark("success run of verifier");
1387 
1388   return success();
1389 }
1390 
1391 LogicalResult TestVerifiersOp::verifyRegions() {
1392   if (!getRegion().hasOneBlock())
1393     return emitOpError("`hasOneBlock` trait hasn't been verified");
1394 
1395   for (Block &block : getRegion())
1396     for (Operation &op : block)
1397       if (failed(mlir::verify(&op)))
1398         return emitOpError("nested op hasn't been verified");
1399 
1400   emitRemark("success run of region verifier");
1401 
1402   return success();
1403 }
1404 
1405 //===----------------------------------------------------------------------===//
1406 // Test InferIntRangeInterface
1407 //===----------------------------------------------------------------------===//
1408 
1409 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1410                                          SetIntRangeFn setResultRanges) {
1411   setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
1412 }
1413 
1414 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
1415                                           OperationState &result) {
1416   if (parser.parseOptionalAttrDict(result.attributes))
1417     return failure();
1418 
1419   // Parse the input argument
1420   OpAsmParser::Argument argInfo;
1421   argInfo.type = parser.getBuilder().getIndexType();
1422   if (failed(parser.parseArgument(argInfo)))
1423     return failure();
1424 
1425   // Parse the body region, and reuse the operand info as the argument info.
1426   Region *body = result.addRegion();
1427   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
1428 }
1429 
1430 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
1431   p.printOptionalAttrDict((*this)->getAttrs());
1432   p << ' ';
1433   p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
1434                         /*omitType=*/true);
1435   p << ' ';
1436   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
1437 }
1438 
1439 void TestWithBoundsRegionOp::inferResultRanges(
1440     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1441   Value arg = getRegion().getArgument(0);
1442   setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
1443 }
1444 
1445 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1446                                         SetIntRangeFn setResultRanges) {
1447   const ConstantIntRanges &range = argRanges[0];
1448   APInt one(range.umin().getBitWidth(), 1);
1449   setResultRanges(getResult(),
1450                   {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
1451                    range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
1452 }
1453 
1454 void TestReflectBoundsOp::inferResultRanges(
1455     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1456   const ConstantIntRanges &range = argRanges[0];
1457   MLIRContext *ctx = getContext();
1458   Builder b(ctx);
1459   setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
1460   setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
1461   setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
1462   setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
1463   setResultRanges(getResult(), range);
1464 }
1465 
1466 #include "TestOpEnums.cpp.inc"
1467 #include "TestOpInterfaces.cpp.inc"
1468 #include "TestTypeInterfaces.cpp.inc"
1469 
1470 #define GET_OP_CLASSES
1471 #include "TestOps.cpp.inc"
1472