1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Reducer/ReductionPatternInterface.h"
22 #include "mlir/Transforms/FoldUtils.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/StringSwitch.h"
25 
26 // Include this before the using namespace lines below to
27 // test that we don't have namespace dependencies.
28 #include "TestOpsDialect.cpp.inc"
29 
30 using namespace mlir;
31 using namespace test;
32 
33 void test::registerTestDialect(DialectRegistry &registry) {
34   registry.insert<TestDialect>();
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // TestDialect Interfaces
39 //===----------------------------------------------------------------------===//
40 
41 namespace {
42 
43 /// Testing the correctness of some traits.
44 static_assert(
45     llvm::is_detected<OpTrait::has_implicit_terminator_t,
46                       SingleBlockImplicitTerminatorOp>::value,
47     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
48 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
49                   SingleBlockImplicitTerminatorOp>::value,
50               "hasSingleBlockImplicitTerminator does not match "
51               "SingleBlockImplicitTerminatorOp");
52 
53 // Test support for interacting with the AsmPrinter.
54 struct TestOpAsmInterface : public OpAsmDialectInterface {
55   using OpAsmDialectInterface::OpAsmDialectInterface;
56 
57   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
58     StringAttr strAttr = attr.dyn_cast<StringAttr>();
59     if (!strAttr)
60       return AliasResult::NoAlias;
61 
62     // Check the contents of the string attribute to see what the test alias
63     // should be named.
64     Optional<StringRef> aliasName =
65         StringSwitch<Optional<StringRef>>(strAttr.getValue())
66             .Case("alias_test:dot_in_name", StringRef("test.alias"))
67             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
68             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
69             .Case("alias_test:sanitize_conflict_a",
70                   StringRef("test_alias_conflict0"))
71             .Case("alias_test:sanitize_conflict_b",
72                   StringRef("test_alias_conflict0_"))
73             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
74             .Default(llvm::None);
75     if (!aliasName)
76       return AliasResult::NoAlias;
77 
78     os << *aliasName;
79     return AliasResult::FinalAlias;
80   }
81 
82   AliasResult getAlias(Type type, raw_ostream &os) const final {
83     if (auto tupleType = type.dyn_cast<TupleType>()) {
84       if (tupleType.size() > 0 &&
85           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
86             return elemType.isa<SimpleAType>();
87           })) {
88         os << "test_tuple";
89         return AliasResult::FinalAlias;
90       }
91     }
92     if (auto intType = type.dyn_cast<TestIntegerType>()) {
93       if (intType.getSignedness() ==
94               TestIntegerType::SignednessSemantics::Unsigned &&
95           intType.getWidth() == 8) {
96         os << "test_ui8";
97         return AliasResult::FinalAlias;
98       }
99     }
100     return AliasResult::NoAlias;
101   }
102 };
103 
104 struct TestDialectFoldInterface : public DialectFoldInterface {
105   using DialectFoldInterface::DialectFoldInterface;
106 
107   /// Registered hook to check if the given region, which is attached to an
108   /// operation that is *not* isolated from above, should be used when
109   /// materializing constants.
110   bool shouldMaterializeInto(Region *region) const final {
111     // If this is a one region operation, then insert into it.
112     return isa<OneRegionOp>(region->getParentOp());
113   }
114 };
115 
116 /// This class defines the interface for handling inlining with standard
117 /// operations.
118 struct TestInlinerInterface : public DialectInlinerInterface {
119   using DialectInlinerInterface::DialectInlinerInterface;
120 
121   //===--------------------------------------------------------------------===//
122   // Analysis Hooks
123   //===--------------------------------------------------------------------===//
124 
125   bool isLegalToInline(Operation *call, Operation *callable,
126                        bool wouldBeCloned) const final {
127     // Don't allow inlining calls that are marked `noinline`.
128     return !call->hasAttr("noinline");
129   }
130   bool isLegalToInline(Region *, Region *, bool,
131                        BlockAndValueMapping &) const final {
132     // Inlining into test dialect regions is legal.
133     return true;
134   }
135   bool isLegalToInline(Operation *, Region *, bool,
136                        BlockAndValueMapping &) const final {
137     return true;
138   }
139 
140   bool shouldAnalyzeRecursively(Operation *op) const final {
141     // Analyze recursively if this is not a functional region operation, it
142     // froms a separate functional scope.
143     return !isa<FunctionalRegionOp>(op);
144   }
145 
146   //===--------------------------------------------------------------------===//
147   // Transformation Hooks
148   //===--------------------------------------------------------------------===//
149 
150   /// Handle the given inlined terminator by replacing it with a new operation
151   /// as necessary.
152   void handleTerminator(Operation *op,
153                         ArrayRef<Value> valuesToRepl) const final {
154     // Only handle "test.return" here.
155     auto returnOp = dyn_cast<TestReturnOp>(op);
156     if (!returnOp)
157       return;
158 
159     // Replace the values directly with the return operands.
160     assert(returnOp.getNumOperands() == valuesToRepl.size());
161     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
162       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
163   }
164 
165   /// Attempt to materialize a conversion for a type mismatch between a call
166   /// from this dialect, and a callable region. This method should generate an
167   /// operation that takes 'input' as the only operand, and produces a single
168   /// result of 'resultType'. If a conversion can not be generated, nullptr
169   /// should be returned.
170   Operation *materializeCallConversion(OpBuilder &builder, Value input,
171                                        Type resultType,
172                                        Location conversionLoc) const final {
173     // Only allow conversion for i16/i32 types.
174     if (!(resultType.isSignlessInteger(16) ||
175           resultType.isSignlessInteger(32)) ||
176         !(input.getType().isSignlessInteger(16) ||
177           input.getType().isSignlessInteger(32)))
178       return nullptr;
179     return builder.create<TestCastOp>(conversionLoc, resultType, input);
180   }
181 
182   void processInlinedCallBlocks(
183       Operation *call,
184       iterator_range<Region::iterator> inlinedBlocks) const final {
185     if (!isa<ConversionCallOp>(call))
186       return;
187 
188     // Set attributed on all ops in the inlined blocks.
189     for (Block &block : inlinedBlocks) {
190       block.walk([&](Operation *op) {
191         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
192       });
193     }
194   }
195 };
196 
197 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
198 public:
199   TestReductionPatternInterface(Dialect *dialect)
200       : DialectReductionPatternInterface(dialect) {}
201 
202   void populateReductionPatterns(RewritePatternSet &patterns) const final {
203     populateTestReductionPatterns(patterns);
204   }
205 };
206 
207 } // namespace
208 
209 //===----------------------------------------------------------------------===//
210 // TestDialect
211 //===----------------------------------------------------------------------===//
212 
213 static void testSideEffectOpGetEffect(
214     Operation *op,
215     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
216 
217 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
218 struct TestOpEffectInterfaceFallback
219     : public TestEffectOpInterface::FallbackModel<
220           TestOpEffectInterfaceFallback> {
221   static bool classof(Operation *op) {
222     bool isSupportedOp =
223         op->getName().getStringRef() == "test.unregistered_side_effect_op";
224     assert(isSupportedOp && "Unexpected dispatch");
225     return isSupportedOp;
226   }
227 
228   void
229   getEffects(Operation *op,
230              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
231                  &effects) const {
232     testSideEffectOpGetEffect(op, effects);
233   }
234 };
235 
/// Registers everything the test dialect provides: attributes, types, the
/// TableGen-generated operations, and the dialect interfaces.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // With GET_OP_LIST defined, the generated include presumably expands to the
  // comma-separated list of op classes used as the template arguments here.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  // Permit operations this dialect does not register, such as
  // "test.unregistered_side_effect_op", which gets a fallback interface below.
  allowUnknownOperations();

  // Instantiate the fallback op interface used for a specific unregistered
  // op. The dialect owns this allocation and frees it in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
251 TestDialect::~TestDialect() {
252   delete static_cast<TestOpEffectInterfaceFallback *>(
253       fallbackEffectOpInterfaces);
254 }
255 
256 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
257                                             Type type, Location loc) {
258   return builder.create<TestOpConstant>(loc, type, value);
259 }
260 
261 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
262     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
263     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
264     ::mlir::RegionRange regions,
265     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
266   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
267   return ::mlir::success();
268 }
269 
270 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
271                                                OperationName opName) {
272   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
273       typeID == TypeID::get<TestEffectOpInterface>())
274     return fallbackEffectOpInterfaces;
275   return nullptr;
276 }
277 
278 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
279                                                     NamedAttribute namedAttr) {
280   if (namedAttr.getName() == "test.invalid_attr")
281     return op->emitError() << "invalid to use 'test.invalid_attr'";
282   return success();
283 }
284 
285 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
286                                                     unsigned regionIndex,
287                                                     unsigned argIndex,
288                                                     NamedAttribute namedAttr) {
289   if (namedAttr.getName() == "test.invalid_attr")
290     return op->emitError() << "invalid to use 'test.invalid_attr'";
291   return success();
292 }
293 
294 LogicalResult
295 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
296                                          unsigned resultIndex,
297                                          NamedAttribute namedAttr) {
298   if (namedAttr.getName() == "test.invalid_attr")
299     return op->emitError() << "invalid to use 'test.invalid_attr'";
300   return success();
301 }
302 
303 Optional<Dialect::ParseOpHook>
304 TestDialect::getParseOperationHook(StringRef opName) const {
305   if (opName == "test.dialect_custom_printer") {
306     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
307       return parser.parseKeyword("custom_format");
308     }};
309   }
310   if (opName == "test.dialect_custom_format_fallback") {
311     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
312       return parser.parseKeyword("custom_format_fallback");
313     }};
314   }
315   return None;
316 }
317 
318 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
319 TestDialect::getOperationPrinter(Operation *op) const {
320   StringRef opName = op->getName().getStringRef();
321   if (opName == "test.dialect_custom_printer") {
322     return [](Operation *op, OpAsmPrinter &printer) {
323       printer.getStream() << " custom_format";
324     };
325   }
326   if (opName == "test.dialect_custom_format_fallback") {
327     return [](Operation *op, OpAsmPrinter &printer) {
328       printer.getStream() << " custom_format_fallback";
329     };
330   }
331   return {};
332 }
333 
334 //===----------------------------------------------------------------------===//
335 // TestBranchOp
336 //===----------------------------------------------------------------------===//
337 
338 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
339   assert(index == 0 && "invalid successor index");
340   return SuccessorOperands(getTargetOperandsMutable());
341 }
342 
343 //===----------------------------------------------------------------------===//
344 // TestProducingBranchOp
345 //===----------------------------------------------------------------------===//
346 
347 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
348   assert(index <= 1 && "invalid successor index");
349   if (index == 1)
350     return SuccessorOperands(getFirstOperandsMutable());
351   return SuccessorOperands(getSecondOperandsMutable());
352 }
353 
354 //===----------------------------------------------------------------------===//
// TestInternalBranchOp
356 //===----------------------------------------------------------------------===//
357 
358 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
359   assert(index <= 1 && "invalid successor index");
360   if (index == 0)
361     return SuccessorOperands(0, getSuccessOperandsMutable());
362   return SuccessorOperands(1, getErrorOperandsMutable());
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // TestDialectCanonicalizerOp
367 //===----------------------------------------------------------------------===//
368 
369 static LogicalResult
370 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
371                                PatternRewriter &rewriter) {
372   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
373       op, rewriter.getI32IntegerAttr(42));
374   return success();
375 }
376 
377 void TestDialect::getCanonicalizationPatterns(
378     RewritePatternSet &results) const {
379   results.add(&dialectCanonicalizationPattern);
380 }
381 
382 //===----------------------------------------------------------------------===//
383 // TestCallOp
384 //===----------------------------------------------------------------------===//
385 
386 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
387   // Check that the callee attribute was specified.
388   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
389   if (!fnAttr)
390     return emitOpError("requires a 'callee' symbol reference attribute");
391   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
392     return emitOpError() << "'" << fnAttr.getValue()
393                          << "' does not reference a valid function";
394   return success();
395 }
396 
397 //===----------------------------------------------------------------------===//
398 // TestFoldToCallOp
399 //===----------------------------------------------------------------------===//
400 
401 namespace {
402 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
403   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
404 
405   LogicalResult matchAndRewrite(FoldToCallOp op,
406                                 PatternRewriter &rewriter) const override {
407     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
408                                               op.getCalleeAttr(), ValueRange());
409     return success();
410   }
411 };
412 } // namespace
413 
414 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
415                                                MLIRContext *context) {
416   results.add<FoldToCallOpPattern>(context);
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // Test Format* operations
421 //===----------------------------------------------------------------------===//
422 
423 //===----------------------------------------------------------------------===//
424 // Parsing
425 
426 static ParseResult parseCustomOptionalOperand(
427     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
428   if (succeeded(parser.parseOptionalLParen())) {
429     optOperand.emplace();
430     if (parser.parseOperand(*optOperand) || parser.parseRParen())
431       return failure();
432   }
433   return success();
434 }
435 
436 static ParseResult parseCustomDirectiveOperands(
437     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
438     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
439     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
440   if (parser.parseOperand(operand))
441     return failure();
442   if (succeeded(parser.parseOptionalComma())) {
443     optOperand.emplace();
444     if (parser.parseOperand(*optOperand))
445       return failure();
446   }
447   if (parser.parseArrow() || parser.parseLParen() ||
448       parser.parseOperandList(varOperands) || parser.parseRParen())
449     return failure();
450   return success();
451 }
452 static ParseResult
453 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
454                             Type &optOperandType,
455                             SmallVectorImpl<Type> &varOperandTypes) {
456   if (parser.parseColon())
457     return failure();
458 
459   if (parser.parseType(operandType))
460     return failure();
461   if (succeeded(parser.parseOptionalComma())) {
462     if (parser.parseType(optOperandType))
463       return failure();
464   }
465   if (parser.parseArrow() || parser.parseLParen() ||
466       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
467     return failure();
468   return success();
469 }
470 static ParseResult
471 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
472                                  Type optOperandType,
473                                  const SmallVectorImpl<Type> &varOperandTypes) {
474   if (parser.parseKeyword("type_refs_capture"))
475     return failure();
476 
477   Type operandType2, optOperandType2;
478   SmallVector<Type, 1> varOperandTypes2;
479   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
480                                   varOperandTypes2))
481     return failure();
482 
483   if (operandType != operandType2 || optOperandType != optOperandType2 ||
484       varOperandTypes != varOperandTypes2)
485     return failure();
486 
487   return success();
488 }
489 static ParseResult parseCustomDirectiveOperandsAndTypes(
490     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
491     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
492     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
493     Type &operandType, Type &optOperandType,
494     SmallVectorImpl<Type> &varOperandTypes) {
495   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
496       parseCustomDirectiveResults(parser, operandType, optOperandType,
497                                   varOperandTypes))
498     return failure();
499   return success();
500 }
501 static ParseResult parseCustomDirectiveRegions(
502     OpAsmParser &parser, Region &region,
503     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
504   if (parser.parseRegion(region))
505     return failure();
506   if (failed(parser.parseOptionalComma()))
507     return success();
508   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
509   if (parser.parseRegion(*varRegion))
510     return failure();
511   varRegions.emplace_back(std::move(varRegion));
512   return success();
513 }
514 static ParseResult
515 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
516                                SmallVectorImpl<Block *> &varSuccessors) {
517   if (parser.parseSuccessor(successor))
518     return failure();
519   if (failed(parser.parseOptionalComma()))
520     return success();
521   Block *varSuccessor;
522   if (parser.parseSuccessor(varSuccessor))
523     return failure();
524   varSuccessors.append(2, varSuccessor);
525   return success();
526 }
527 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
528                                                   IntegerAttr &attr,
529                                                   IntegerAttr &optAttr) {
530   if (parser.parseAttribute(attr))
531     return failure();
532   if (succeeded(parser.parseOptionalComma())) {
533     if (parser.parseAttribute(optAttr))
534       return failure();
535   }
536   return success();
537 }
538 
539 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
540                                                 NamedAttrList &attrs) {
541   return parser.parseOptionalAttrDict(attrs);
542 }
543 static ParseResult parseCustomDirectiveOptionalOperandRef(
544     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
545   int64_t operandCount = 0;
546   if (parser.parseInteger(operandCount))
547     return failure();
548   bool expectedOptionalOperand = operandCount == 0;
549   return success(expectedOptionalOperand != optOperand.hasValue());
550 }
551 
552 //===----------------------------------------------------------------------===//
553 // Printing
554 
555 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
556                                        Value optOperand) {
557   if (optOperand)
558     printer << "(" << optOperand << ") ";
559 }
560 
561 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
562                                          Value operand, Value optOperand,
563                                          OperandRange varOperands) {
564   printer << operand;
565   if (optOperand)
566     printer << ", " << optOperand;
567   printer << " -> (" << varOperands << ")";
568 }
569 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
570                                         Type operandType, Type optOperandType,
571                                         TypeRange varOperandTypes) {
572   printer << " : " << operandType;
573   if (optOperandType)
574     printer << ", " << optOperandType;
575   printer << " -> (" << varOperandTypes << ")";
576 }
577 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
578                                              Operation *op, Type operandType,
579                                              Type optOperandType,
580                                              TypeRange varOperandTypes) {
581   printer << " type_refs_capture ";
582   printCustomDirectiveResults(printer, op, operandType, optOperandType,
583                               varOperandTypes);
584 }
585 static void printCustomDirectiveOperandsAndTypes(
586     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
587     OperandRange varOperands, Type operandType, Type optOperandType,
588     TypeRange varOperandTypes) {
589   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
590   printCustomDirectiveResults(printer, op, operandType, optOperandType,
591                               varOperandTypes);
592 }
593 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
594                                         Region &region,
595                                         MutableArrayRef<Region> varRegions) {
596   printer.printRegion(region);
597   if (!varRegions.empty()) {
598     printer << ", ";
599     for (Region &region : varRegions)
600       printer.printRegion(region);
601   }
602 }
603 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
604                                            Block *successor,
605                                            SuccessorRange varSuccessors) {
606   printer << successor;
607   if (!varSuccessors.empty())
608     printer << ", " << varSuccessors.front();
609 }
610 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
611                                            Attribute attribute,
612                                            Attribute optAttribute) {
613   printer << attribute;
614   if (optAttribute)
615     printer << ", " << optAttribute;
616 }
617 
618 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
619                                          DictionaryAttr attrs) {
620   printer.printOptionalAttrDict(attrs.getValue());
621 }
622 
623 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
624                                                    Operation *op,
625                                                    Value optOperand) {
626   printer << (optOperand ? "1" : "0");
627 }
628 
629 //===----------------------------------------------------------------------===//
630 // Test IsolatedRegionOp - parse passthrough region arguments.
631 //===----------------------------------------------------------------------===//
632 
633 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
634                                     OperationState &result) {
635   OpAsmParser::UnresolvedOperand argInfo;
636   Type argType = parser.getBuilder().getIndexType();
637 
638   // Parse the input operand.
639   if (parser.parseOperand(argInfo) ||
640       parser.resolveOperand(argInfo, argType, result.operands))
641     return failure();
642 
643   // Parse the body region, and reuse the operand info as the argument info.
644   Region *body = result.addRegion();
645   return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{},
646                             /*enableNameShadowing=*/true);
647 }
648 
649 void IsolatedRegionOp::print(OpAsmPrinter &p) {
650   p << "test.isolated_region ";
651   p.printOperand(getOperand());
652   p.shadowRegionArgs(getRegion(), getOperand());
653   p << ' ';
654   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
655 }
656 
657 //===----------------------------------------------------------------------===//
658 // Test SSACFGRegionOp
659 //===----------------------------------------------------------------------===//
660 
661 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
662   return RegionKind::SSACFG;
663 }
664 
665 //===----------------------------------------------------------------------===//
666 // Test GraphRegionOp
667 //===----------------------------------------------------------------------===//
668 
669 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) {
670   // Parse the body region, and reuse the operand info as the argument info.
671   Region *body = result.addRegion();
672   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
673 }
674 
675 void GraphRegionOp::print(OpAsmPrinter &p) {
676   p << "test.graph_region ";
677   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
678 }
679 
680 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
681   return RegionKind::Graph;
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // Test AffineScopeOp
686 //===----------------------------------------------------------------------===//
687 
688 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
689   // Parse the body region, and reuse the operand info as the argument info.
690   Region *body = result.addRegion();
691   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
692 }
693 
694 void AffineScopeOp::print(OpAsmPrinter &p) {
695   p << "test.affine_scope ";
696   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // Test parser.
701 //===----------------------------------------------------------------------===//
702 
703 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
704                                          OperationState &result) {
705   if (parser.parseOptionalColon())
706     return success();
707   uint64_t numResults;
708   if (parser.parseInteger(numResults))
709     return failure();
710 
711   IndexType type = parser.getBuilder().getIndexType();
712   for (unsigned i = 0; i < numResults; ++i)
713     result.addTypes(type);
714   return success();
715 }
716 
717 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
718   if (unsigned numResults = getNumResults())
719     p << " : " << numResults;
720 }
721 
722 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
723                                          OperationState &result) {
724   StringRef keyword;
725   if (parser.parseKeyword(&keyword))
726     return failure();
727   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
728   return success();
729 }
730 
731 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
732 
733 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
735 
/// Parses `wraps <generic-op>`.
///
/// The wrapped operation is parsed in generic form into a newly created
/// single-block region. A TestReturnOp forwarding all of the wrapped op's
/// results is then appended as the block terminator. The wrapping op takes
/// its result types from that terminator's operands, and its location from
/// the wrapped op.
ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  if (parser.parseKeyword("wraps"))
    return failure();

  // Parse the wrapped op in a region
  // The heap-allocated block is handed to the region's block list.
  Region &body = *result.addRegion();
  body.push_back(new Block);
  Block &block = body.back();
  // The block is still empty, so the parsed op becomes its first operation.
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
  if (!wrappedOp)
    return failure();

  // Create a return terminator in the inner region, pass as operand to the
  // terminator the returned values from the wrapped operation.
  SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
  OpBuilder builder(parser.getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);

  // Get the results type for the wrapping op from the terminator operands.
  // The terminator just created is the last operation of the region's block.
  Operation &returnOp = body.back().back();
  result.types.append(returnOp.operand_type_begin(),
                      returnOp.operand_type_end());

  // Use the location of the wrapped op for the "test.wrapping_region" op.
  result.location = wrappedOp->getLoc();

  return success();
}
766 
767 void WrappingRegionOp::print(OpAsmPrinter &p) {
768   p << " wraps ";
769   p.printGenericOp(&getRegion().front().front());
770 }
771 
772 //===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
774 //   parseGenericOperationAfterOpName
775 //   parseCustomOperationName
776 //===----------------------------------------------------------------------===//
777 
778 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
779                                          OperationState &result) {
780 
781   SMLoc loc = parser.getCurrentLocation();
782   Location currLocation = parser.getEncodedSourceLoc(loc);
783 
784   // Parse the operands.
785   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
786   if (parser.parseOperandList(operands))
787     return failure();
788 
789   // Check if we are parsing the pretty-printed version
790   //  test.pretty_printed_region start <inner-op> end : <functional-type>
791   // Else fallback to parsing the "non pretty-printed" version.
792   if (!succeeded(parser.parseOptionalKeyword("start")))
793     return parser.parseGenericOperationAfterOpName(
794         result, llvm::makeArrayRef(operands));
795 
796   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
797   if (failed(parseOpNameInfo))
798     return failure();
799 
800   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
801 
802   FunctionType opFntype;
803   Optional<Location> explicitLoc;
804   if (parser.parseKeyword("end") || parser.parseColon() ||
805       parser.parseType(opFntype) ||
806       parser.parseOptionalLocationSpecifier(explicitLoc))
807     return failure();
808 
809   // If location of the op is explicitly provided, then use it; Else use
810   // the parser's current location.
811   Location opLoc = explicitLoc.getValueOr(currLocation);
812 
813   // Derive the SSA-values for op's operands.
814   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
815                              result.operands))
816     return failure();
817 
818   // Add a region for op.
819   Region &region = *result.addRegion();
820 
821   // Create a basic-block inside op's region.
822   Block &block = region.emplaceBlock();
823 
824   // Create and insert an "inner-op" operation in the block.
825   // Just for testing purposes, we can assume that inner op is a binary op with
826   // result and operand types all same as the test-op's first operand.
827   Type innerOpType = opFntype.getInput(0);
828   Value lhs = block.addArgument(innerOpType, opLoc);
829   Value rhs = block.addArgument(innerOpType, opLoc);
830 
831   OpBuilder builder(parser.getBuilder().getContext());
832   builder.setInsertionPointToStart(&block);
833 
834   Operation *innerOp =
835       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
836 
837   // Insert a return statement in the block returning the inner-op's result.
838   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
839 
840   // Populate the op operation-state with result-type and location.
841   result.addTypes(opFntype.getResults());
842   result.location = innerOp->getLoc();
843 
844   return success();
845 }
846 
847 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
848   p << ' ';
849   p.printOperands(getOperands());
850 
851   Operation &innerOp = getRegion().front().front();
852   // Assuming that region has a single non-terminator inner-op, if the inner-op
853   // meets some criteria (which in this case is a simple one  based on the name
854   // of inner-op), then we can print the entire region in a succinct way.
855   // Here we assume that the prototype of "special.op" can be trivially derived
856   // while parsing it back.
857   if (innerOp.getName().getStringRef().equals("special.op")) {
858     p << " start special.op end";
859   } else {
860     p << " (";
861     p.printRegion(getRegion());
862     p << ")";
863   }
864 
865   p << " : ";
866   p.printFunctionalType(*this);
867 }
868 
869 //===----------------------------------------------------------------------===//
870 // Test PolyForOp - parse list of region arguments.
871 //===----------------------------------------------------------------------===//
872 
873 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
874   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
875   // Parse list of region arguments without a delimiter.
876   if (parser.parseRegionArgumentList(ivsInfo))
877     return failure();
878 
879   // Parse the body region.
880   Region *body = result.addRegion();
881   auto &builder = parser.getBuilder();
882   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
883   return parser.parseRegion(*body, ivsInfo, argTypes);
884 }
885 
886 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
887 
888 void PolyForOp::getAsmBlockArgumentNames(Region &region,
889                                          OpAsmSetValueNameFn setNameFn) {
890   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
891   if (!arrayAttr)
892     return;
893   auto args = getRegion().front().getArguments();
894   auto e = std::min(arrayAttr.size(), args.size());
895   for (unsigned i = 0; i < e; ++i) {
896     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
897       setNameFn(args[i], strAttr.getValue());
898   }
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // Test removing op with inner ops.
903 //===----------------------------------------------------------------------===//
904 
905 namespace {
906 struct TestRemoveOpWithInnerOps
907     : public OpRewritePattern<TestOpWithRegionPattern> {
908   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
909 
910   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
911 
912   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
913                                 PatternRewriter &rewriter) const override {
914     rewriter.eraseOp(op);
915     return success();
916   }
917 };
918 } // namespace
919 
920 void TestOpWithRegionPattern::getCanonicalizationPatterns(
921     RewritePatternSet &results, MLIRContext *context) {
922   results.add<TestRemoveOpWithInnerOps>(context);
923 }
924 
925 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
926   return getOperand();
927 }
928 
929 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
930   return getValue();
931 }
932 
933 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
934     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
935   for (Value input : this->getOperands()) {
936     results.push_back(input);
937   }
938   return success();
939 }
940 
941 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
942   assert(operands.size() == 1);
943   if (operands.front()) {
944     (*this)->setAttr("attr", operands.front());
945     return getResult();
946   }
947   return {};
948 }
949 
950 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
951   return getOperand();
952 }
953 
954 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
955     MLIRContext *, Optional<Location> location, ValueRange operands,
956     DictionaryAttr attributes, RegionRange regions,
957     SmallVectorImpl<Type> &inferredReturnTypes) {
958   if (operands[0].getType() != operands[1].getType()) {
959     return emitOptionalError(location, "operand type mismatch ",
960                              operands[0].getType(), " vs ",
961                              operands[1].getType());
962   }
963   inferredReturnTypes.assign({operands[0].getType()});
964   return success();
965 }
966 
967 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
968     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
969     DictionaryAttr attributes, RegionRange regions,
970     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
971   // Create return type consisting of the last element of the first operand.
972   auto operandType = operands.front().getType();
973   auto sval = operandType.dyn_cast<ShapedType>();
974   if (!sval) {
975     return emitOptionalError(location, "only shaped type operands allowed");
976   }
977   int64_t dim =
978       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
979   auto type = IntegerType::get(context, 17);
980   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
981   return success();
982 }
983 
984 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
985     OpBuilder &builder, ValueRange operands,
986     llvm::SmallVectorImpl<Value> &shapes) {
987   shapes = SmallVector<Value, 1>{
988       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
989   return success();
990 }
991 
992 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
993     OpBuilder &builder, ValueRange operands,
994     llvm::SmallVectorImpl<Value> &shapes) {
995   Location loc = getLoc();
996   shapes.reserve(operands.size());
997   for (Value operand : llvm::reverse(operands)) {
998     auto rank = operand.getType().cast<RankedTensorType>().getRank();
999     auto currShape = llvm::to_vector<4>(
1000         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1001           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1002         }));
1003     shapes.push_back(builder.create<tensor::FromElementsOp>(
1004         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1005         currShape));
1006   }
1007   return success();
1008 }
1009 
1010 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1011     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1012   Location loc = getLoc();
1013   shapes.reserve(getNumOperands());
1014   for (Value operand : llvm::reverse(getOperands())) {
1015     auto currShape = llvm::to_vector<4>(llvm::map_range(
1016         llvm::seq<int64_t>(
1017             0, operand.getType().cast<RankedTensorType>().getRank()),
1018         [&](int64_t dim) -> Value {
1019           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1020         }));
1021     shapes.emplace_back(std::move(currShape));
1022   }
1023   return success();
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // Test SideEffect interfaces
1028 //===----------------------------------------------------------------------===//
1029 
namespace {
/// A test resource for side effects.
/// SideEffectOp::getEffects attaches effects to this resource, instead of the
/// default one, when an effect entry carries a "test_resource" key. The
/// resource reports its name as "<Test>".
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  // Explicit inline TypeID definition for this internal resource class.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)

  StringRef getName() final { return "<Test>"; }
};
} // namespace
1038 
1039 static void testSideEffectOpGetEffect(
1040     Operation *op,
1041     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1042         &effects) {
1043   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1044   if (!effectsAttr)
1045     return;
1046 
1047   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1048 }
1049 
1050 void SideEffectOp::getEffects(
1051     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1052   // Check for an effects attribute on the op instance.
1053   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1054   if (!effectsAttr)
1055     return;
1056 
1057   // If there is one, it is an array of dictionary attributes that hold
1058   // information on the effects of this operation.
1059   for (Attribute element : effectsAttr) {
1060     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1061 
1062     // Get the specific memory effect.
1063     MemoryEffects::Effect *effect =
1064         StringSwitch<MemoryEffects::Effect *>(
1065             effectElement.get("effect").cast<StringAttr>().getValue())
1066             .Case("allocate", MemoryEffects::Allocate::get())
1067             .Case("free", MemoryEffects::Free::get())
1068             .Case("read", MemoryEffects::Read::get())
1069             .Case("write", MemoryEffects::Write::get());
1070 
1071     // Check for a non-default resource to use.
1072     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1073     if (effectElement.get("test_resource"))
1074       resource = TestResource::get();
1075 
1076     // Check for a result to affect.
1077     if (effectElement.get("on_result"))
1078       effects.emplace_back(effect, getResult(), resource);
1079     else if (Attribute ref = effectElement.get("on_reference"))
1080       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1081     else
1082       effects.emplace_back(effect, resource);
1083   }
1084 }
1085 
1086 void SideEffectOp::getEffects(
1087     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1088   testSideEffectOpGetEffect(getOperation(), effects);
1089 }
1090 
1091 //===----------------------------------------------------------------------===//
1092 // StringAttrPrettyNameOp
1093 //===----------------------------------------------------------------------===//
1094 
1095 // This op has fancy handling of its SSA result name.
1096 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1097                                           OperationState &result) {
1098   // Add the result types.
1099   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1100     result.addTypes(parser.getBuilder().getIntegerType(32));
1101 
1102   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1103     return failure();
1104 
1105   // If the attribute dictionary contains no 'names' attribute, infer it from
1106   // the SSA name (if specified).
1107   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1108     return attr.getName() == "names";
1109   });
1110 
1111   // If there was no name specified, check to see if there was a useful name
1112   // specified in the asm file.
1113   if (hadNames || parser.getNumResults() == 0)
1114     return success();
1115 
1116   SmallVector<StringRef, 4> names;
1117   auto *context = result.getContext();
1118 
1119   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1120     auto resultName = parser.getResultName(i);
1121     StringRef nameStr;
1122     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1123       nameStr = resultName.first;
1124 
1125     names.push_back(nameStr);
1126   }
1127 
1128   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1129   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1130   return success();
1131 }
1132 
1133 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1134   // Note that we only need to print the "name" attribute if the asmprinter
1135   // result name disagrees with it.  This can happen in strange cases, e.g.
1136   // when there are conflicts.
1137   bool namesDisagree = getNames().size() != getNumResults();
1138 
1139   SmallString<32> resultNameStr;
1140   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1141     resultNameStr.clear();
1142     llvm::raw_svector_ostream tmpStream(resultNameStr);
1143     p.printOperand(getResult(i), tmpStream);
1144 
1145     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1146     if (!expectedName ||
1147         tmpStream.str().drop_front() != expectedName.getValue()) {
1148       namesDisagree = true;
1149     }
1150   }
1151 
1152   if (namesDisagree)
1153     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1154   else
1155     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1156 }
1157 
1158 // We set the SSA name in the asm syntax to the contents of the name
1159 // attribute.
1160 void StringAttrPrettyNameOp::getAsmResultNames(
1161     function_ref<void(Value, StringRef)> setNameFn) {
1162 
1163   auto value = getNames();
1164   for (size_t i = 0, e = value.size(); i != e; ++i)
1165     if (auto str = value[i].dyn_cast<StringAttr>())
1166       if (!str.getValue().empty())
1167         setNameFn(getResult(i), str.getValue());
1168 }
1169 
1170 //===----------------------------------------------------------------------===//
1171 // ResultTypeWithTraitOp
1172 //===----------------------------------------------------------------------===//
1173 
1174 LogicalResult ResultTypeWithTraitOp::verify() {
1175   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1176     return success();
1177   return emitError("result type should have trait 'TestTypeTrait'");
1178 }
1179 
1180 //===----------------------------------------------------------------------===//
1181 // AttrWithTraitOp
1182 //===----------------------------------------------------------------------===//
1183 
1184 LogicalResult AttrWithTraitOp::verify() {
1185   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1186     return success();
1187   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1188 }
1189 
1190 //===----------------------------------------------------------------------===//
1191 // RegionIfOp
1192 //===----------------------------------------------------------------------===//
1193 
1194 void RegionIfOp::print(OpAsmPrinter &p) {
1195   p << " ";
1196   p.printOperands(getOperands());
1197   p << ": " << getOperandTypes();
1198   p.printArrowTypeList(getResultTypes());
1199   p << " then ";
1200   p.printRegion(getThenRegion(),
1201                 /*printEntryBlockArgs=*/true,
1202                 /*printBlockTerminators=*/true);
1203   p << " else ";
1204   p.printRegion(getElseRegion(),
1205                 /*printEntryBlockArgs=*/true,
1206                 /*printBlockTerminators=*/true);
1207   p << " join ";
1208   p.printRegion(getJoinRegion(),
1209                 /*printEntryBlockArgs=*/true,
1210                 /*printBlockTerminators=*/true);
1211 }
1212 
1213 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1214   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1215   SmallVector<Type, 2> operandTypes;
1216 
1217   result.regions.reserve(3);
1218   Region *thenRegion = result.addRegion();
1219   Region *elseRegion = result.addRegion();
1220   Region *joinRegion = result.addRegion();
1221 
1222   // Parse operand, type and arrow type lists.
1223   if (parser.parseOperandList(operandInfos) ||
1224       parser.parseColonTypeList(operandTypes) ||
1225       parser.parseArrowTypeList(result.types))
1226     return failure();
1227 
1228   // Parse all attached regions.
1229   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1230       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1231       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1232     return failure();
1233 
1234   return parser.resolveOperands(operandInfos, operandTypes,
1235                                 parser.getCurrentLocation(), result.operands);
1236 }
1237 
1238 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1239   assert(index < 2 && "invalid region index");
1240   return getOperands();
1241 }
1242 
1243 void RegionIfOp::getSuccessorRegions(
1244     Optional<unsigned> index, ArrayRef<Attribute> operands,
1245     SmallVectorImpl<RegionSuccessor> &regions) {
1246   // We always branch to the join region.
1247   if (index.hasValue()) {
1248     if (index.getValue() < 2)
1249       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1250     else
1251       regions.push_back(RegionSuccessor(getResults()));
1252     return;
1253   }
1254 
1255   // The then and else regions are the entry regions of this op.
1256   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1257   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1258 }
1259 
1260 void RegionIfOp::getRegionInvocationBounds(
1261     ArrayRef<Attribute> operands,
1262     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1263   // Each region is invoked at most once.
1264   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1265 }
1266 
1267 //===----------------------------------------------------------------------===//
1268 // AnyCondOp
1269 //===----------------------------------------------------------------------===//
1270 
1271 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1272                                     ArrayRef<Attribute> operands,
1273                                     SmallVectorImpl<RegionSuccessor> &regions) {
1274   // The parent op branches into the only region, and the region branches back
1275   // to the parent op.
1276   if (index)
1277     regions.emplace_back(&getRegion());
1278   else
1279     regions.emplace_back(getResults());
1280 }
1281 
1282 void AnyCondOp::getRegionInvocationBounds(
1283     ArrayRef<Attribute> operands,
1284     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1285   invocationBounds.emplace_back(1, 1);
1286 }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // SingleNoTerminatorCustomAsmOp
1290 //===----------------------------------------------------------------------===//
1291 
1292 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1293                                                  OperationState &state) {
1294   Region *body = state.addRegion();
1295   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1296     return failure();
1297   return success();
1298 }
1299 
1300 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1301   printer.printRegion(
1302       getRegion(), /*printEntryBlockArgs=*/false,
1303       // This op has a single block without terminators. But explicitly mark
1304       // as not printing block terminators for testing.
1305       /*printBlockTerminators=*/false);
1306 }
1307 
1308 #include "TestOpEnums.cpp.inc"
1309 #include "TestOpInterfaces.cpp.inc"
1310 #include "TestOpStructs.cpp.inc"
1311 #include "TestTypeInterfaces.cpp.inc"
1312 
1313 #define GET_OP_CLASSES
1314 #include "TestOps.cpp.inc"
1315