1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/IR/Verifier.h"
22 #include "mlir/Reducer/ReductionPatternInterface.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/StringExtras.h"
26 #include "llvm/ADT/StringSwitch.h"
27 
28 // Include this before the using namespace lines below to
29 // test that we don't have namespace dependencies.
30 #include "TestOpsDialect.cpp.inc"
31 
32 using namespace mlir;
33 using namespace test;
34 
35 void test::registerTestDialect(DialectRegistry &registry) {
36   registry.insert<TestDialect>();
37 }
38 
39 //===----------------------------------------------------------------------===//
40 // TestDialect Interfaces
41 //===----------------------------------------------------------------------===//
42 
43 namespace {
44 
45 /// Testing the correctness of some traits.
46 static_assert(
47     llvm::is_detected<OpTrait::has_implicit_terminator_t,
48                       SingleBlockImplicitTerminatorOp>::value,
49     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
50 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
51                   SingleBlockImplicitTerminatorOp>::value,
52               "hasSingleBlockImplicitTerminator does not match "
53               "SingleBlockImplicitTerminatorOp");
54 
55 // Test support for interacting with the AsmPrinter.
56 struct TestOpAsmInterface : public OpAsmDialectInterface {
57   using OpAsmDialectInterface::OpAsmDialectInterface;
58 
59   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
60     StringAttr strAttr = attr.dyn_cast<StringAttr>();
61     if (!strAttr)
62       return AliasResult::NoAlias;
63 
64     // Check the contents of the string attribute to see what the test alias
65     // should be named.
66     Optional<StringRef> aliasName =
67         StringSwitch<Optional<StringRef>>(strAttr.getValue())
68             .Case("alias_test:dot_in_name", StringRef("test.alias"))
69             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
70             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
71             .Case("alias_test:sanitize_conflict_a",
72                   StringRef("test_alias_conflict0"))
73             .Case("alias_test:sanitize_conflict_b",
74                   StringRef("test_alias_conflict0_"))
75             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
76             .Default(llvm::None);
77     if (!aliasName)
78       return AliasResult::NoAlias;
79 
80     os << *aliasName;
81     return AliasResult::FinalAlias;
82   }
83 
84   AliasResult getAlias(Type type, raw_ostream &os) const final {
85     if (auto tupleType = type.dyn_cast<TupleType>()) {
86       if (tupleType.size() > 0 &&
87           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
88             return elemType.isa<SimpleAType>();
89           })) {
90         os << "test_tuple";
91         return AliasResult::FinalAlias;
92       }
93     }
94     if (auto intType = type.dyn_cast<TestIntegerType>()) {
95       if (intType.getSignedness() ==
96               TestIntegerType::SignednessSemantics::Unsigned &&
97           intType.getWidth() == 8) {
98         os << "test_ui8";
99         return AliasResult::FinalAlias;
100       }
101     }
102     return AliasResult::NoAlias;
103   }
104 };
105 
106 struct TestDialectFoldInterface : public DialectFoldInterface {
107   using DialectFoldInterface::DialectFoldInterface;
108 
109   /// Registered hook to check if the given region, which is attached to an
110   /// operation that is *not* isolated from above, should be used when
111   /// materializing constants.
112   bool shouldMaterializeInto(Region *region) const final {
113     // If this is a one region operation, then insert into it.
114     return isa<OneRegionOp>(region->getParentOp());
115   }
116 };
117 
118 /// This class defines the interface for handling inlining with standard
119 /// operations.
120 struct TestInlinerInterface : public DialectInlinerInterface {
121   using DialectInlinerInterface::DialectInlinerInterface;
122 
123   //===--------------------------------------------------------------------===//
124   // Analysis Hooks
125   //===--------------------------------------------------------------------===//
126 
127   bool isLegalToInline(Operation *call, Operation *callable,
128                        bool wouldBeCloned) const final {
129     // Don't allow inlining calls that are marked `noinline`.
130     return !call->hasAttr("noinline");
131   }
132   bool isLegalToInline(Region *, Region *, bool,
133                        BlockAndValueMapping &) const final {
134     // Inlining into test dialect regions is legal.
135     return true;
136   }
137   bool isLegalToInline(Operation *, Region *, bool,
138                        BlockAndValueMapping &) const final {
139     return true;
140   }
141 
142   bool shouldAnalyzeRecursively(Operation *op) const final {
143     // Analyze recursively if this is not a functional region operation, it
144     // froms a separate functional scope.
145     return !isa<FunctionalRegionOp>(op);
146   }
147 
148   //===--------------------------------------------------------------------===//
149   // Transformation Hooks
150   //===--------------------------------------------------------------------===//
151 
152   /// Handle the given inlined terminator by replacing it with a new operation
153   /// as necessary.
154   void handleTerminator(Operation *op,
155                         ArrayRef<Value> valuesToRepl) const final {
156     // Only handle "test.return" here.
157     auto returnOp = dyn_cast<TestReturnOp>(op);
158     if (!returnOp)
159       return;
160 
161     // Replace the values directly with the return operands.
162     assert(returnOp.getNumOperands() == valuesToRepl.size());
163     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
164       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
165   }
166 
167   /// Attempt to materialize a conversion for a type mismatch between a call
168   /// from this dialect, and a callable region. This method should generate an
169   /// operation that takes 'input' as the only operand, and produces a single
170   /// result of 'resultType'. If a conversion can not be generated, nullptr
171   /// should be returned.
172   Operation *materializeCallConversion(OpBuilder &builder, Value input,
173                                        Type resultType,
174                                        Location conversionLoc) const final {
175     // Only allow conversion for i16/i32 types.
176     if (!(resultType.isSignlessInteger(16) ||
177           resultType.isSignlessInteger(32)) ||
178         !(input.getType().isSignlessInteger(16) ||
179           input.getType().isSignlessInteger(32)))
180       return nullptr;
181     return builder.create<TestCastOp>(conversionLoc, resultType, input);
182   }
183 
184   void processInlinedCallBlocks(
185       Operation *call,
186       iterator_range<Region::iterator> inlinedBlocks) const final {
187     if (!isa<ConversionCallOp>(call))
188       return;
189 
190     // Set attributed on all ops in the inlined blocks.
191     for (Block &block : inlinedBlocks) {
192       block.walk([&](Operation *op) {
193         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
194       });
195     }
196   }
197 };
198 
199 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
200 public:
201   TestReductionPatternInterface(Dialect *dialect)
202       : DialectReductionPatternInterface(dialect) {}
203 
204   void populateReductionPatterns(RewritePatternSet &patterns) const final {
205     populateTestReductionPatterns(patterns);
206   }
207 };
208 
209 } // namespace
210 
211 //===----------------------------------------------------------------------===//
212 // TestDialect
213 //===----------------------------------------------------------------------===//
214 
215 static void testSideEffectOpGetEffect(
216     Operation *op,
217     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
218 
219 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
220 struct TestOpEffectInterfaceFallback
221     : public TestEffectOpInterface::FallbackModel<
222           TestOpEffectInterfaceFallback> {
223   static bool classof(Operation *op) {
224     bool isSupportedOp =
225         op->getName().getStringRef() == "test.unregistered_side_effect_op";
226     assert(isSupportedOp && "Unexpected dispatch");
227     return isSupportedOp;
228   }
229 
230   void
231   getEffects(Operation *op,
232              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
233                  &effects) const {
234     testSideEffectOpGetEffect(op, effects);
235   }
236 };
237 
/// Dialect setup: registers the test attributes and types, every op from the
/// generated TestOps.cpp.inc op list, and the dialect interfaces defined above.
/// It also allows unknown operations and creates the effect-interface fallback.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Accept operations that this dialect does not register (for example the
  // `test.unregistered_side_effect_op` served by the fallback below).
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. The allocation is released in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
253 TestDialect::~TestDialect() {
254   delete static_cast<TestOpEffectInterfaceFallback *>(
255       fallbackEffectOpInterfaces);
256 }
257 
258 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
259                                             Type type, Location loc) {
260   return builder.create<TestOpConstant>(loc, type, value);
261 }
262 
263 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
264     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
265     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
266     ::mlir::RegionRange regions,
267     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
268   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
269   return ::mlir::success();
270 }
271 
272 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
273                                                OperationName opName) {
274   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
275       typeID == TypeID::get<TestEffectOpInterface>())
276     return fallbackEffectOpInterfaces;
277   return nullptr;
278 }
279 
280 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
281                                                     NamedAttribute namedAttr) {
282   if (namedAttr.getName() == "test.invalid_attr")
283     return op->emitError() << "invalid to use 'test.invalid_attr'";
284   return success();
285 }
286 
287 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
288                                                     unsigned regionIndex,
289                                                     unsigned argIndex,
290                                                     NamedAttribute namedAttr) {
291   if (namedAttr.getName() == "test.invalid_attr")
292     return op->emitError() << "invalid to use 'test.invalid_attr'";
293   return success();
294 }
295 
296 LogicalResult
297 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
298                                          unsigned resultIndex,
299                                          NamedAttribute namedAttr) {
300   if (namedAttr.getName() == "test.invalid_attr")
301     return op->emitError() << "invalid to use 'test.invalid_attr'";
302   return success();
303 }
304 
305 Optional<Dialect::ParseOpHook>
306 TestDialect::getParseOperationHook(StringRef opName) const {
307   if (opName == "test.dialect_custom_printer") {
308     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
309       return parser.parseKeyword("custom_format");
310     }};
311   }
312   if (opName == "test.dialect_custom_format_fallback") {
313     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
314       return parser.parseKeyword("custom_format_fallback");
315     }};
316   }
317   return None;
318 }
319 
320 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
321 TestDialect::getOperationPrinter(Operation *op) const {
322   StringRef opName = op->getName().getStringRef();
323   if (opName == "test.dialect_custom_printer") {
324     return [](Operation *op, OpAsmPrinter &printer) {
325       printer.getStream() << " custom_format";
326     };
327   }
328   if (opName == "test.dialect_custom_format_fallback") {
329     return [](Operation *op, OpAsmPrinter &printer) {
330       printer.getStream() << " custom_format_fallback";
331     };
332   }
333   return {};
334 }
335 
336 //===----------------------------------------------------------------------===//
337 // TestBranchOp
338 //===----------------------------------------------------------------------===//
339 
340 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
341   assert(index == 0 && "invalid successor index");
342   return SuccessorOperands(getTargetOperandsMutable());
343 }
344 
345 //===----------------------------------------------------------------------===//
346 // TestProducingBranchOp
347 //===----------------------------------------------------------------------===//
348 
349 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
350   assert(index <= 1 && "invalid successor index");
351   if (index == 1)
352     return SuccessorOperands(getFirstOperandsMutable());
353   return SuccessorOperands(getSecondOperandsMutable());
354 }
355 
356 //===----------------------------------------------------------------------===//
// TestInternalBranchOp
358 //===----------------------------------------------------------------------===//
359 
360 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
361   assert(index <= 1 && "invalid successor index");
362   if (index == 0)
363     return SuccessorOperands(0, getSuccessOperandsMutable());
364   return SuccessorOperands(1, getErrorOperandsMutable());
365 }
366 
367 //===----------------------------------------------------------------------===//
368 // TestDialectCanonicalizerOp
369 //===----------------------------------------------------------------------===//
370 
371 static LogicalResult
372 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
373                                PatternRewriter &rewriter) {
374   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
375       op, rewriter.getI32IntegerAttr(42));
376   return success();
377 }
378 
379 void TestDialect::getCanonicalizationPatterns(
380     RewritePatternSet &results) const {
381   results.add(&dialectCanonicalizationPattern);
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // TestCallOp
386 //===----------------------------------------------------------------------===//
387 
388 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
389   // Check that the callee attribute was specified.
390   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
391   if (!fnAttr)
392     return emitOpError("requires a 'callee' symbol reference attribute");
393   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
394     return emitOpError() << "'" << fnAttr.getValue()
395                          << "' does not reference a valid function";
396   return success();
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // TestFoldToCallOp
401 //===----------------------------------------------------------------------===//
402 
403 namespace {
404 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
405   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
406 
407   LogicalResult matchAndRewrite(FoldToCallOp op,
408                                 PatternRewriter &rewriter) const override {
409     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
410                                               op.getCalleeAttr(), ValueRange());
411     return success();
412   }
413 };
414 } // namespace
415 
416 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
417                                                MLIRContext *context) {
418   results.add<FoldToCallOpPattern>(context);
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // Test Format* operations
423 //===----------------------------------------------------------------------===//
424 
425 //===----------------------------------------------------------------------===//
426 // Parsing
427 
428 static ParseResult parseCustomOptionalOperand(
429     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
430   if (succeeded(parser.parseOptionalLParen())) {
431     optOperand.emplace();
432     if (parser.parseOperand(*optOperand) || parser.parseRParen())
433       return failure();
434   }
435   return success();
436 }
437 
438 static ParseResult parseCustomDirectiveOperands(
439     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
440     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
441     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
442   if (parser.parseOperand(operand))
443     return failure();
444   if (succeeded(parser.parseOptionalComma())) {
445     optOperand.emplace();
446     if (parser.parseOperand(*optOperand))
447       return failure();
448   }
449   if (parser.parseArrow() || parser.parseLParen() ||
450       parser.parseOperandList(varOperands) || parser.parseRParen())
451     return failure();
452   return success();
453 }
454 static ParseResult
455 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
456                             Type &optOperandType,
457                             SmallVectorImpl<Type> &varOperandTypes) {
458   if (parser.parseColon())
459     return failure();
460 
461   if (parser.parseType(operandType))
462     return failure();
463   if (succeeded(parser.parseOptionalComma())) {
464     if (parser.parseType(optOperandType))
465       return failure();
466   }
467   if (parser.parseArrow() || parser.parseLParen() ||
468       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
469     return failure();
470   return success();
471 }
472 static ParseResult
473 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
474                                  Type optOperandType,
475                                  const SmallVectorImpl<Type> &varOperandTypes) {
476   if (parser.parseKeyword("type_refs_capture"))
477     return failure();
478 
479   Type operandType2, optOperandType2;
480   SmallVector<Type, 1> varOperandTypes2;
481   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
482                                   varOperandTypes2))
483     return failure();
484 
485   if (operandType != operandType2 || optOperandType != optOperandType2 ||
486       varOperandTypes != varOperandTypes2)
487     return failure();
488 
489   return success();
490 }
491 static ParseResult parseCustomDirectiveOperandsAndTypes(
492     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
493     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
494     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
495     Type &operandType, Type &optOperandType,
496     SmallVectorImpl<Type> &varOperandTypes) {
497   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
498       parseCustomDirectiveResults(parser, operandType, optOperandType,
499                                   varOperandTypes))
500     return failure();
501   return success();
502 }
503 static ParseResult parseCustomDirectiveRegions(
504     OpAsmParser &parser, Region &region,
505     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
506   if (parser.parseRegion(region))
507     return failure();
508   if (failed(parser.parseOptionalComma()))
509     return success();
510   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
511   if (parser.parseRegion(*varRegion))
512     return failure();
513   varRegions.emplace_back(std::move(varRegion));
514   return success();
515 }
516 static ParseResult
517 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
518                                SmallVectorImpl<Block *> &varSuccessors) {
519   if (parser.parseSuccessor(successor))
520     return failure();
521   if (failed(parser.parseOptionalComma()))
522     return success();
523   Block *varSuccessor;
524   if (parser.parseSuccessor(varSuccessor))
525     return failure();
526   varSuccessors.append(2, varSuccessor);
527   return success();
528 }
529 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
530                                                   IntegerAttr &attr,
531                                                   IntegerAttr &optAttr) {
532   if (parser.parseAttribute(attr))
533     return failure();
534   if (succeeded(parser.parseOptionalComma())) {
535     if (parser.parseAttribute(optAttr))
536       return failure();
537   }
538   return success();
539 }
540 
541 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
542                                                 NamedAttrList &attrs) {
543   return parser.parseOptionalAttrDict(attrs);
544 }
545 static ParseResult parseCustomDirectiveOptionalOperandRef(
546     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
547   int64_t operandCount = 0;
548   if (parser.parseInteger(operandCount))
549     return failure();
550   bool expectedOptionalOperand = operandCount == 0;
551   return success(expectedOptionalOperand != optOperand.hasValue());
552 }
553 
554 //===----------------------------------------------------------------------===//
555 // Printing
556 
557 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
558                                        Value optOperand) {
559   if (optOperand)
560     printer << "(" << optOperand << ") ";
561 }
562 
563 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
564                                          Value operand, Value optOperand,
565                                          OperandRange varOperands) {
566   printer << operand;
567   if (optOperand)
568     printer << ", " << optOperand;
569   printer << " -> (" << varOperands << ")";
570 }
571 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
572                                         Type operandType, Type optOperandType,
573                                         TypeRange varOperandTypes) {
574   printer << " : " << operandType;
575   if (optOperandType)
576     printer << ", " << optOperandType;
577   printer << " -> (" << varOperandTypes << ")";
578 }
579 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
580                                              Operation *op, Type operandType,
581                                              Type optOperandType,
582                                              TypeRange varOperandTypes) {
583   printer << " type_refs_capture ";
584   printCustomDirectiveResults(printer, op, operandType, optOperandType,
585                               varOperandTypes);
586 }
587 static void printCustomDirectiveOperandsAndTypes(
588     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
589     OperandRange varOperands, Type operandType, Type optOperandType,
590     TypeRange varOperandTypes) {
591   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
592   printCustomDirectiveResults(printer, op, operandType, optOperandType,
593                               varOperandTypes);
594 }
595 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
596                                         Region &region,
597                                         MutableArrayRef<Region> varRegions) {
598   printer.printRegion(region);
599   if (!varRegions.empty()) {
600     printer << ", ";
601     for (Region &region : varRegions)
602       printer.printRegion(region);
603   }
604 }
605 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
606                                            Block *successor,
607                                            SuccessorRange varSuccessors) {
608   printer << successor;
609   if (!varSuccessors.empty())
610     printer << ", " << varSuccessors.front();
611 }
612 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
613                                            Attribute attribute,
614                                            Attribute optAttribute) {
615   printer << attribute;
616   if (optAttribute)
617     printer << ", " << optAttribute;
618 }
619 
620 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
621                                          DictionaryAttr attrs) {
622   printer.printOptionalAttrDict(attrs.getValue());
623 }
624 
625 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
626                                                    Operation *op,
627                                                    Value optOperand) {
628   printer << (optOperand ? "1" : "0");
629 }
630 
631 //===----------------------------------------------------------------------===//
632 // Test IsolatedRegionOp - parse passthrough region arguments.
633 //===----------------------------------------------------------------------===//
634 
635 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
636                                     OperationState &result) {
637   OpAsmParser::UnresolvedOperand argInfo;
638   Type argType = parser.getBuilder().getIndexType();
639 
640   // Parse the input operand.
641   if (parser.parseOperand(argInfo) ||
642       parser.resolveOperand(argInfo, argType, result.operands))
643     return failure();
644 
645   // Parse the body region, and reuse the operand info as the argument info.
646   Region *body = result.addRegion();
647   return parser.parseRegion(*body, argInfo, argType,
648                             /*enableNameShadowing=*/true);
649 }
650 
651 void IsolatedRegionOp::print(OpAsmPrinter &p) {
652   p << "test.isolated_region ";
653   p.printOperand(getOperand());
654   p.shadowRegionArgs(getRegion(), getOperand());
655   p << ' ';
656   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // Test SSACFGRegionOp
661 //===----------------------------------------------------------------------===//
662 
663 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
664   return RegionKind::SSACFG;
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // Test GraphRegionOp
669 //===----------------------------------------------------------------------===//
670 
671 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) {
672   // Parse the body region, and reuse the operand info as the argument info.
673   Region *body = result.addRegion();
674   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
675 }
676 
677 void GraphRegionOp::print(OpAsmPrinter &p) {
678   p << "test.graph_region ";
679   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
680 }
681 
682 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
683   return RegionKind::Graph;
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // Test AffineScopeOp
688 //===----------------------------------------------------------------------===//
689 
690 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
691   // Parse the body region, and reuse the operand info as the argument info.
692   Region *body = result.addRegion();
693   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
694 }
695 
696 void AffineScopeOp::print(OpAsmPrinter &p) {
697   p << "test.affine_scope ";
698   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
699 }
700 
701 //===----------------------------------------------------------------------===//
702 // Test parser.
703 //===----------------------------------------------------------------------===//
704 
705 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
706                                          OperationState &result) {
707   if (parser.parseOptionalColon())
708     return success();
709   uint64_t numResults;
710   if (parser.parseInteger(numResults))
711     return failure();
712 
713   IndexType type = parser.getBuilder().getIndexType();
714   for (unsigned i = 0; i < numResults; ++i)
715     result.addTypes(type);
716   return success();
717 }
718 
719 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
720   if (unsigned numResults = getNumResults())
721     p << " : " << numResults;
722 }
723 
724 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
725                                          OperationState &result) {
726   StringRef keyword;
727   if (parser.parseKeyword(&keyword))
728     return failure();
729   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
730   return success();
731 }
732 
733 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
734 
735 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
737 
738 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
739                                     OperationState &result) {
740   if (parser.parseKeyword("wraps"))
741     return failure();
742 
743   // Parse the wrapped op in a region
744   Region &body = *result.addRegion();
745   body.push_back(new Block);
746   Block &block = body.back();
747   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
748   if (!wrappedOp)
749     return failure();
750 
751   // Create a return terminator in the inner region, pass as operand to the
752   // terminator the returned values from the wrapped operation.
753   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
754   OpBuilder builder(parser.getContext());
755   builder.setInsertionPointToEnd(&block);
756   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
757 
758   // Get the results type for the wrapping op from the terminator operands.
759   Operation &returnOp = body.back().back();
760   result.types.append(returnOp.operand_type_begin(),
761                       returnOp.operand_type_end());
762 
763   // Use the location of the wrapped op for the "test.wrapping_region" op.
764   result.location = wrappedOp->getLoc();
765 
766   return success();
767 }
768 
769 void WrappingRegionOp::print(OpAsmPrinter &p) {
770   p << " wraps ";
771   p.printGenericOp(&getRegion().front().front());
772 }
773 
774 //===----------------------------------------------------------------------===//
775 // Test PrettyPrintedRegionOp -  exercising the following parser APIs
776 //   parseGenericOperationAfterOpName
777 //   parseCustomOperationName
778 //===----------------------------------------------------------------------===//
779 
780 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
781                                          OperationState &result) {
782 
783   SMLoc loc = parser.getCurrentLocation();
784   Location currLocation = parser.getEncodedSourceLoc(loc);
785 
786   // Parse the operands.
787   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
788   if (parser.parseOperandList(operands))
789     return failure();
790 
791   // Check if we are parsing the pretty-printed version
792   //  test.pretty_printed_region start <inner-op> end : <functional-type>
793   // Else fallback to parsing the "non pretty-printed" version.
794   if (!succeeded(parser.parseOptionalKeyword("start")))
795     return parser.parseGenericOperationAfterOpName(
796         result, llvm::makeArrayRef(operands));
797 
798   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
799   if (failed(parseOpNameInfo))
800     return failure();
801 
802   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
803 
804   FunctionType opFntype;
805   Optional<Location> explicitLoc;
806   if (parser.parseKeyword("end") || parser.parseColon() ||
807       parser.parseType(opFntype) ||
808       parser.parseOptionalLocationSpecifier(explicitLoc))
809     return failure();
810 
811   // If location of the op is explicitly provided, then use it; Else use
812   // the parser's current location.
813   Location opLoc = explicitLoc.getValueOr(currLocation);
814 
815   // Derive the SSA-values for op's operands.
816   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
817                              result.operands))
818     return failure();
819 
820   // Add a region for op.
821   Region &region = *result.addRegion();
822 
823   // Create a basic-block inside op's region.
824   Block &block = region.emplaceBlock();
825 
826   // Create and insert an "inner-op" operation in the block.
827   // Just for testing purposes, we can assume that inner op is a binary op with
828   // result and operand types all same as the test-op's first operand.
829   Type innerOpType = opFntype.getInput(0);
830   Value lhs = block.addArgument(innerOpType, opLoc);
831   Value rhs = block.addArgument(innerOpType, opLoc);
832 
833   OpBuilder builder(parser.getBuilder().getContext());
834   builder.setInsertionPointToStart(&block);
835 
836   Operation *innerOp =
837       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
838 
839   // Insert a return statement in the block returning the inner-op's result.
840   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
841 
842   // Populate the op operation-state with result-type and location.
843   result.addTypes(opFntype.getResults());
844   result.location = innerOp->getLoc();
845 
846   return success();
847 }
848 
849 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
850   p << ' ';
851   p.printOperands(getOperands());
852 
853   Operation &innerOp = getRegion().front().front();
854   // Assuming that region has a single non-terminator inner-op, if the inner-op
855   // meets some criteria (which in this case is a simple one  based on the name
856   // of inner-op), then we can print the entire region in a succinct way.
857   // Here we assume that the prototype of "special.op" can be trivially derived
858   // while parsing it back.
859   if (innerOp.getName().getStringRef().equals("special.op")) {
860     p << " start special.op end";
861   } else {
862     p << " (";
863     p.printRegion(getRegion());
864     p << ")";
865   }
866 
867   p << " : ";
868   p.printFunctionalType(*this);
869 }
870 
871 //===----------------------------------------------------------------------===//
872 // Test PolyForOp - parse list of region arguments.
873 //===----------------------------------------------------------------------===//
874 
875 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
876   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
877   // Parse list of region arguments without a delimiter.
878   if (parser.parseRegionArgumentList(ivsInfo))
879     return failure();
880 
881   // Parse the body region.
882   Region *body = result.addRegion();
883   auto &builder = parser.getBuilder();
884   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
885   return parser.parseRegion(*body, ivsInfo, argTypes);
886 }
887 
888 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
889 
890 void PolyForOp::getAsmBlockArgumentNames(Region &region,
891                                          OpAsmSetValueNameFn setNameFn) {
892   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
893   if (!arrayAttr)
894     return;
895   auto args = getRegion().front().getArguments();
896   auto e = std::min(arrayAttr.size(), args.size());
897   for (unsigned i = 0; i < e; ++i) {
898     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
899       setNameFn(args[i], strAttr.getValue());
900   }
901 }
902 
903 //===----------------------------------------------------------------------===//
904 // Test removing op with inner ops.
905 //===----------------------------------------------------------------------===//
906 
907 namespace {
908 struct TestRemoveOpWithInnerOps
909     : public OpRewritePattern<TestOpWithRegionPattern> {
910   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
911 
912   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
913 
914   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
915                                 PatternRewriter &rewriter) const override {
916     rewriter.eraseOp(op);
917     return success();
918   }
919 };
920 } // namespace
921 
/// Registers `TestRemoveOpWithInnerOps`, which erases this op (and everything
/// nested inside it) during canonicalization.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<TestRemoveOpWithInnerOps>(context);
}
926 
927 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
928   return getOperand();
929 }
930 
931 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
932   return getValue();
933 }
934 
935 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
936     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
937   for (Value input : this->getOperands()) {
938     results.push_back(input);
939   }
940   return success();
941 }
942 
943 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
944   assert(operands.size() == 1);
945   if (operands.front()) {
946     (*this)->setAttr("attr", operands.front());
947     return getResult();
948   }
949   return {};
950 }
951 
952 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
953   return getOperand();
954 }
955 
956 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
957     MLIRContext *, Optional<Location> location, ValueRange operands,
958     DictionaryAttr attributes, RegionRange regions,
959     SmallVectorImpl<Type> &inferredReturnTypes) {
960   if (operands[0].getType() != operands[1].getType()) {
961     return emitOptionalError(location, "operand type mismatch ",
962                              operands[0].getType(), " vs ",
963                              operands[1].getType());
964   }
965   inferredReturnTypes.assign({operands[0].getType()});
966   return success();
967 }
968 
969 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
970     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
971     DictionaryAttr attributes, RegionRange regions,
972     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
973   // Create return type consisting of the last element of the first operand.
974   auto operandType = operands.front().getType();
975   auto sval = operandType.dyn_cast<ShapedType>();
976   if (!sval) {
977     return emitOptionalError(location, "only shaped type operands allowed");
978   }
979   int64_t dim =
980       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
981   auto type = IntegerType::get(context, 17);
982   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
983   return success();
984 }
985 
986 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
987     OpBuilder &builder, ValueRange operands,
988     llvm::SmallVectorImpl<Value> &shapes) {
989   shapes = SmallVector<Value, 1>{
990       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
991   return success();
992 }
993 
994 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
995     OpBuilder &builder, ValueRange operands,
996     llvm::SmallVectorImpl<Value> &shapes) {
997   Location loc = getLoc();
998   shapes.reserve(operands.size());
999   for (Value operand : llvm::reverse(operands)) {
1000     auto rank = operand.getType().cast<RankedTensorType>().getRank();
1001     auto currShape = llvm::to_vector<4>(
1002         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1003           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1004         }));
1005     shapes.push_back(builder.create<tensor::FromElementsOp>(
1006         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1007         currShape));
1008   }
1009   return success();
1010 }
1011 
1012 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1013     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1014   Location loc = getLoc();
1015   shapes.reserve(getNumOperands());
1016   for (Value operand : llvm::reverse(getOperands())) {
1017     auto currShape = llvm::to_vector<4>(llvm::map_range(
1018         llvm::seq<int64_t>(
1019             0, operand.getType().cast<RankedTensorType>().getRank()),
1020         [&](int64_t dim) -> Value {
1021           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1022         }));
1023     shapes.emplace_back(std::move(currShape));
1024   }
1025   return success();
1026 }
1027 
1028 //===----------------------------------------------------------------------===//
1029 // Test SideEffect interfaces
1030 //===----------------------------------------------------------------------===//
1031 
namespace {
/// A test resource for side effects. `SideEffectOp` attributes an effect to
/// this resource, instead of the default resource, when the effect's
/// dictionary contains a `test_resource` entry.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  // Gives the resource an explicitly defined, file-internal TypeID.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)

  /// Name reported for this resource.
  StringRef getName() final { return "<Test>"; }
};
} // namespace
1040 
1041 static void testSideEffectOpGetEffect(
1042     Operation *op,
1043     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1044         &effects) {
1045   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1046   if (!effectsAttr)
1047     return;
1048 
1049   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1050 }
1051 
1052 void SideEffectOp::getEffects(
1053     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1054   // Check for an effects attribute on the op instance.
1055   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1056   if (!effectsAttr)
1057     return;
1058 
1059   // If there is one, it is an array of dictionary attributes that hold
1060   // information on the effects of this operation.
1061   for (Attribute element : effectsAttr) {
1062     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1063 
1064     // Get the specific memory effect.
1065     MemoryEffects::Effect *effect =
1066         StringSwitch<MemoryEffects::Effect *>(
1067             effectElement.get("effect").cast<StringAttr>().getValue())
1068             .Case("allocate", MemoryEffects::Allocate::get())
1069             .Case("free", MemoryEffects::Free::get())
1070             .Case("read", MemoryEffects::Read::get())
1071             .Case("write", MemoryEffects::Write::get());
1072 
1073     // Check for a non-default resource to use.
1074     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1075     if (effectElement.get("test_resource"))
1076       resource = TestResource::get();
1077 
1078     // Check for a result to affect.
1079     if (effectElement.get("on_result"))
1080       effects.emplace_back(effect, getResult(), resource);
1081     else if (Attribute ref = effectElement.get("on_reference"))
1082       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1083     else
1084       effects.emplace_back(effect, resource);
1085   }
1086 }
1087 
1088 void SideEffectOp::getEffects(
1089     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1090   testSideEffectOpGetEffect(getOperation(), effects);
1091 }
1092 
1093 //===----------------------------------------------------------------------===//
1094 // StringAttrPrettyNameOp
1095 //===----------------------------------------------------------------------===//
1096 
1097 // This op has fancy handling of its SSA result name.
1098 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1099                                           OperationState &result) {
1100   // Add the result types.
1101   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1102     result.addTypes(parser.getBuilder().getIntegerType(32));
1103 
1104   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1105     return failure();
1106 
1107   // If the attribute dictionary contains no 'names' attribute, infer it from
1108   // the SSA name (if specified).
1109   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1110     return attr.getName() == "names";
1111   });
1112 
1113   // If there was no name specified, check to see if there was a useful name
1114   // specified in the asm file.
1115   if (hadNames || parser.getNumResults() == 0)
1116     return success();
1117 
1118   SmallVector<StringRef, 4> names;
1119   auto *context = result.getContext();
1120 
1121   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1122     auto resultName = parser.getResultName(i);
1123     StringRef nameStr;
1124     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1125       nameStr = resultName.first;
1126 
1127     names.push_back(nameStr);
1128   }
1129 
1130   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1131   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1132   return success();
1133 }
1134 
1135 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1136   // Note that we only need to print the "name" attribute if the asmprinter
1137   // result name disagrees with it.  This can happen in strange cases, e.g.
1138   // when there are conflicts.
1139   bool namesDisagree = getNames().size() != getNumResults();
1140 
1141   SmallString<32> resultNameStr;
1142   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1143     resultNameStr.clear();
1144     llvm::raw_svector_ostream tmpStream(resultNameStr);
1145     p.printOperand(getResult(i), tmpStream);
1146 
1147     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1148     if (!expectedName ||
1149         tmpStream.str().drop_front() != expectedName.getValue()) {
1150       namesDisagree = true;
1151     }
1152   }
1153 
1154   if (namesDisagree)
1155     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1156   else
1157     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1158 }
1159 
1160 // We set the SSA name in the asm syntax to the contents of the name
1161 // attribute.
1162 void StringAttrPrettyNameOp::getAsmResultNames(
1163     function_ref<void(Value, StringRef)> setNameFn) {
1164 
1165   auto value = getNames();
1166   for (size_t i = 0, e = value.size(); i != e; ++i)
1167     if (auto str = value[i].dyn_cast<StringAttr>())
1168       if (!str.getValue().empty())
1169         setNameFn(getResult(i), str.getValue());
1170 }
1171 
1172 void CustomResultsNameOp::getAsmResultNames(
1173     function_ref<void(Value, StringRef)> setNameFn) {
1174   ArrayAttr value = getNames();
1175   for (size_t i = 0, e = value.size(); i != e; ++i)
1176     if (auto str = value[i].dyn_cast<StringAttr>())
1177       if (!str.getValue().empty())
1178         setNameFn(getResult(i), str.getValue());
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // ResultTypeWithTraitOp
1183 //===----------------------------------------------------------------------===//
1184 
1185 LogicalResult ResultTypeWithTraitOp::verify() {
1186   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1187     return success();
1188   return emitError("result type should have trait 'TestTypeTrait'");
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // AttrWithTraitOp
1193 //===----------------------------------------------------------------------===//
1194 
1195 LogicalResult AttrWithTraitOp::verify() {
1196   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1197     return success();
1198   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1199 }
1200 
1201 //===----------------------------------------------------------------------===//
1202 // RegionIfOp
1203 //===----------------------------------------------------------------------===//
1204 
1205 void RegionIfOp::print(OpAsmPrinter &p) {
1206   p << " ";
1207   p.printOperands(getOperands());
1208   p << ": " << getOperandTypes();
1209   p.printArrowTypeList(getResultTypes());
1210   p << " then ";
1211   p.printRegion(getThenRegion(),
1212                 /*printEntryBlockArgs=*/true,
1213                 /*printBlockTerminators=*/true);
1214   p << " else ";
1215   p.printRegion(getElseRegion(),
1216                 /*printEntryBlockArgs=*/true,
1217                 /*printBlockTerminators=*/true);
1218   p << " join ";
1219   p.printRegion(getJoinRegion(),
1220                 /*printEntryBlockArgs=*/true,
1221                 /*printBlockTerminators=*/true);
1222 }
1223 
1224 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1225   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1226   SmallVector<Type, 2> operandTypes;
1227 
1228   result.regions.reserve(3);
1229   Region *thenRegion = result.addRegion();
1230   Region *elseRegion = result.addRegion();
1231   Region *joinRegion = result.addRegion();
1232 
1233   // Parse operand, type and arrow type lists.
1234   if (parser.parseOperandList(operandInfos) ||
1235       parser.parseColonTypeList(operandTypes) ||
1236       parser.parseArrowTypeList(result.types))
1237     return failure();
1238 
1239   // Parse all attached regions.
1240   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1241       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1242       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1243     return failure();
1244 
1245   return parser.resolveOperands(operandInfos, operandTypes,
1246                                 parser.getCurrentLocation(), result.operands);
1247 }
1248 
/// Returns the values passed from the parent op into the `then` (index 0) or
/// `else` (index 1) region: all of the op's operands. The `join` region
/// (index 2) is not entered from the parent op, so it is rejected.
OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
  assert(index < 2 && "invalid region index");
  return getOperands();
}
1253 
1254 void RegionIfOp::getSuccessorRegions(
1255     Optional<unsigned> index, ArrayRef<Attribute> operands,
1256     SmallVectorImpl<RegionSuccessor> &regions) {
1257   // We always branch to the join region.
1258   if (index.hasValue()) {
1259     if (index.getValue() < 2)
1260       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1261     else
1262       regions.push_back(RegionSuccessor(getResults()));
1263     return;
1264   }
1265 
1266   // The then and else regions are the entry regions of this op.
1267   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1268   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1269 }
1270 
1271 void RegionIfOp::getRegionInvocationBounds(
1272     ArrayRef<Attribute> operands,
1273     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1274   // Each region is invoked at most once.
1275   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1276 }
1277 
1278 //===----------------------------------------------------------------------===//
1279 // AnyCondOp
1280 //===----------------------------------------------------------------------===//
1281 
1282 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1283                                     ArrayRef<Attribute> operands,
1284                                     SmallVectorImpl<RegionSuccessor> &regions) {
1285   // The parent op branches into the only region, and the region branches back
1286   // to the parent op.
1287   if (index)
1288     regions.emplace_back(&getRegion());
1289   else
1290     regions.emplace_back(getResults());
1291 }
1292 
1293 void AnyCondOp::getRegionInvocationBounds(
1294     ArrayRef<Attribute> operands,
1295     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1296   invocationBounds.emplace_back(1, 1);
1297 }
1298 
1299 //===----------------------------------------------------------------------===//
1300 // SingleNoTerminatorCustomAsmOp
1301 //===----------------------------------------------------------------------===//
1302 
1303 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1304                                                  OperationState &state) {
1305   Region *body = state.addRegion();
1306   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1307     return failure();
1308   return success();
1309 }
1310 
1311 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1312   printer.printRegion(
1313       getRegion(), /*printEntryBlockArgs=*/false,
1314       // This op has a single block without terminators. But explicitly mark
1315       // as not printing block terminators for testing.
1316       /*printBlockTerminators=*/false);
1317 }
1318 
1319 //===----------------------------------------------------------------------===//
1320 // TestVerifiersOp
1321 //===----------------------------------------------------------------------===//
1322 
1323 LogicalResult TestVerifiersOp::verify() {
1324   if (!getRegion().hasOneBlock())
1325     return emitOpError("`hasOneBlock` trait hasn't been verified");
1326 
1327   Operation *definingOp = getInput().getDefiningOp();
1328   if (definingOp && failed(mlir::verify(definingOp)))
1329     return emitOpError("operand hasn't been verified");
1330 
1331   emitRemark("success run of verifier");
1332 
1333   return success();
1334 }
1335 
1336 LogicalResult TestVerifiersOp::verifyRegions() {
1337   if (!getRegion().hasOneBlock())
1338     return emitOpError("`hasOneBlock` trait hasn't been verified");
1339 
1340   for (Block &block : getRegion())
1341     for (Operation &op : block)
1342       if (failed(mlir::verify(&op)))
1343         return emitOpError("nested op hasn't been verified");
1344 
1345   emitRemark("success run of region verifier");
1346 
1347   return success();
1348 }
1349 
1350 #include "TestOpEnums.cpp.inc"
1351 #include "TestOpInterfaces.cpp.inc"
1352 #include "TestOpStructs.cpp.inc"
1353 #include "TestTypeInterfaces.cpp.inc"
1354 
1355 #define GET_OP_CLASSES
1356 #include "TestOps.cpp.inc"
1357