1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/IR/Verifier.h"
22 #include "mlir/Reducer/ReductionPatternInterface.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/StringSwitch.h"
26 
27 // Include this before the using namespace lines below to
28 // test that we don't have namespace dependencies.
29 #include "TestOpsDialect.cpp.inc"
30 
31 using namespace mlir;
32 using namespace test;
33 
34 void test::registerTestDialect(DialectRegistry &registry) {
35   registry.insert<TestDialect>();
36 }
37 
38 //===----------------------------------------------------------------------===//
39 // TestDialect Interfaces
40 //===----------------------------------------------------------------------===//
41 
42 namespace {
43 
44 /// Testing the correctness of some traits.
45 static_assert(
46     llvm::is_detected<OpTrait::has_implicit_terminator_t,
47                       SingleBlockImplicitTerminatorOp>::value,
48     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
49 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
50                   SingleBlockImplicitTerminatorOp>::value,
51               "hasSingleBlockImplicitTerminator does not match "
52               "SingleBlockImplicitTerminatorOp");
53 
54 // Test support for interacting with the AsmPrinter.
55 struct TestOpAsmInterface : public OpAsmDialectInterface {
56   using OpAsmDialectInterface::OpAsmDialectInterface;
57 
58   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
59     StringAttr strAttr = attr.dyn_cast<StringAttr>();
60     if (!strAttr)
61       return AliasResult::NoAlias;
62 
63     // Check the contents of the string attribute to see what the test alias
64     // should be named.
65     Optional<StringRef> aliasName =
66         StringSwitch<Optional<StringRef>>(strAttr.getValue())
67             .Case("alias_test:dot_in_name", StringRef("test.alias"))
68             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
69             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
70             .Case("alias_test:sanitize_conflict_a",
71                   StringRef("test_alias_conflict0"))
72             .Case("alias_test:sanitize_conflict_b",
73                   StringRef("test_alias_conflict0_"))
74             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
75             .Default(llvm::None);
76     if (!aliasName)
77       return AliasResult::NoAlias;
78 
79     os << *aliasName;
80     return AliasResult::FinalAlias;
81   }
82 
83   AliasResult getAlias(Type type, raw_ostream &os) const final {
84     if (auto tupleType = type.dyn_cast<TupleType>()) {
85       if (tupleType.size() > 0 &&
86           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
87             return elemType.isa<SimpleAType>();
88           })) {
89         os << "test_tuple";
90         return AliasResult::FinalAlias;
91       }
92     }
93     if (auto intType = type.dyn_cast<TestIntegerType>()) {
94       if (intType.getSignedness() ==
95               TestIntegerType::SignednessSemantics::Unsigned &&
96           intType.getWidth() == 8) {
97         os << "test_ui8";
98         return AliasResult::FinalAlias;
99       }
100     }
101     return AliasResult::NoAlias;
102   }
103 };
104 
105 struct TestDialectFoldInterface : public DialectFoldInterface {
106   using DialectFoldInterface::DialectFoldInterface;
107 
108   /// Registered hook to check if the given region, which is attached to an
109   /// operation that is *not* isolated from above, should be used when
110   /// materializing constants.
111   bool shouldMaterializeInto(Region *region) const final {
112     // If this is a one region operation, then insert into it.
113     return isa<OneRegionOp>(region->getParentOp());
114   }
115 };
116 
117 /// This class defines the interface for handling inlining with standard
118 /// operations.
119 struct TestInlinerInterface : public DialectInlinerInterface {
120   using DialectInlinerInterface::DialectInlinerInterface;
121 
122   //===--------------------------------------------------------------------===//
123   // Analysis Hooks
124   //===--------------------------------------------------------------------===//
125 
126   bool isLegalToInline(Operation *call, Operation *callable,
127                        bool wouldBeCloned) const final {
128     // Don't allow inlining calls that are marked `noinline`.
129     return !call->hasAttr("noinline");
130   }
131   bool isLegalToInline(Region *, Region *, bool,
132                        BlockAndValueMapping &) const final {
133     // Inlining into test dialect regions is legal.
134     return true;
135   }
136   bool isLegalToInline(Operation *, Region *, bool,
137                        BlockAndValueMapping &) const final {
138     return true;
139   }
140 
141   bool shouldAnalyzeRecursively(Operation *op) const final {
142     // Analyze recursively if this is not a functional region operation, it
143     // froms a separate functional scope.
144     return !isa<FunctionalRegionOp>(op);
145   }
146 
147   //===--------------------------------------------------------------------===//
148   // Transformation Hooks
149   //===--------------------------------------------------------------------===//
150 
151   /// Handle the given inlined terminator by replacing it with a new operation
152   /// as necessary.
153   void handleTerminator(Operation *op,
154                         ArrayRef<Value> valuesToRepl) const final {
155     // Only handle "test.return" here.
156     auto returnOp = dyn_cast<TestReturnOp>(op);
157     if (!returnOp)
158       return;
159 
160     // Replace the values directly with the return operands.
161     assert(returnOp.getNumOperands() == valuesToRepl.size());
162     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
163       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
164   }
165 
166   /// Attempt to materialize a conversion for a type mismatch between a call
167   /// from this dialect, and a callable region. This method should generate an
168   /// operation that takes 'input' as the only operand, and produces a single
169   /// result of 'resultType'. If a conversion can not be generated, nullptr
170   /// should be returned.
171   Operation *materializeCallConversion(OpBuilder &builder, Value input,
172                                        Type resultType,
173                                        Location conversionLoc) const final {
174     // Only allow conversion for i16/i32 types.
175     if (!(resultType.isSignlessInteger(16) ||
176           resultType.isSignlessInteger(32)) ||
177         !(input.getType().isSignlessInteger(16) ||
178           input.getType().isSignlessInteger(32)))
179       return nullptr;
180     return builder.create<TestCastOp>(conversionLoc, resultType, input);
181   }
182 
183   void processInlinedCallBlocks(
184       Operation *call,
185       iterator_range<Region::iterator> inlinedBlocks) const final {
186     if (!isa<ConversionCallOp>(call))
187       return;
188 
189     // Set attributed on all ops in the inlined blocks.
190     for (Block &block : inlinedBlocks) {
191       block.walk([&](Operation *op) {
192         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
193       });
194     }
195   }
196 };
197 
198 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
199 public:
200   TestReductionPatternInterface(Dialect *dialect)
201       : DialectReductionPatternInterface(dialect) {}
202 
203   void populateReductionPatterns(RewritePatternSet &patterns) const final {
204     populateTestReductionPatterns(patterns);
205   }
206 };
207 
208 } // namespace
209 
210 //===----------------------------------------------------------------------===//
211 // TestDialect
212 //===----------------------------------------------------------------------===//
213 
214 static void testSideEffectOpGetEffect(
215     Operation *op,
216     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
217 
218 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
219 struct TestOpEffectInterfaceFallback
220     : public TestEffectOpInterface::FallbackModel<
221           TestOpEffectInterfaceFallback> {
222   static bool classof(Operation *op) {
223     bool isSupportedOp =
224         op->getName().getStringRef() == "test.unregistered_side_effect_op";
225     assert(isSupportedOp && "Unexpected dispatch");
226     return isSupportedOp;
227   }
228 
229   void
230   getEffects(Operation *op,
231              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
232                  &effects) const {
233     testSideEffectOpGetEffect(op, effects);
234   }
235 };
236 
/// Registers the dialect's attributes, types, ODS-generated operations and
/// dialect interfaces, and permits unknown (unregistered) operations.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // With GET_OP_LIST defined, the generated file expands to the op classes
  // used as the template arguments below.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
/// Frees the fallback model allocated in initialize(). The member is handed out
/// type-erased (see getRegisteredInterfaceForOp, which returns it as `void *`),
/// so it is cast back to its concrete type before deletion.
TestDialect::~TestDialect() {
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}
256 
257 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
258                                             Type type, Location loc) {
259   return builder.create<TestOpConstant>(loc, type, value);
260 }
261 
262 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
263     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
264     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
265     ::mlir::RegionRange regions,
266     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
267   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
268   return ::mlir::success();
269 }
270 
271 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
272                                                OperationName opName) {
273   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
274       typeID == TypeID::get<TestEffectOpInterface>())
275     return fallbackEffectOpInterfaces;
276   return nullptr;
277 }
278 
279 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
280                                                     NamedAttribute namedAttr) {
281   if (namedAttr.getName() == "test.invalid_attr")
282     return op->emitError() << "invalid to use 'test.invalid_attr'";
283   return success();
284 }
285 
286 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
287                                                     unsigned regionIndex,
288                                                     unsigned argIndex,
289                                                     NamedAttribute namedAttr) {
290   if (namedAttr.getName() == "test.invalid_attr")
291     return op->emitError() << "invalid to use 'test.invalid_attr'";
292   return success();
293 }
294 
295 LogicalResult
296 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
297                                          unsigned resultIndex,
298                                          NamedAttribute namedAttr) {
299   if (namedAttr.getName() == "test.invalid_attr")
300     return op->emitError() << "invalid to use 'test.invalid_attr'";
301   return success();
302 }
303 
304 Optional<Dialect::ParseOpHook>
305 TestDialect::getParseOperationHook(StringRef opName) const {
306   if (opName == "test.dialect_custom_printer") {
307     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
308       return parser.parseKeyword("custom_format");
309     }};
310   }
311   if (opName == "test.dialect_custom_format_fallback") {
312     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
313       return parser.parseKeyword("custom_format_fallback");
314     }};
315   }
316   return None;
317 }
318 
319 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
320 TestDialect::getOperationPrinter(Operation *op) const {
321   StringRef opName = op->getName().getStringRef();
322   if (opName == "test.dialect_custom_printer") {
323     return [](Operation *op, OpAsmPrinter &printer) {
324       printer.getStream() << " custom_format";
325     };
326   }
327   if (opName == "test.dialect_custom_format_fallback") {
328     return [](Operation *op, OpAsmPrinter &printer) {
329       printer.getStream() << " custom_format_fallback";
330     };
331   }
332   return {};
333 }
334 
335 //===----------------------------------------------------------------------===//
336 // TestBranchOp
337 //===----------------------------------------------------------------------===//
338 
339 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
340   assert(index == 0 && "invalid successor index");
341   return SuccessorOperands(getTargetOperandsMutable());
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // TestProducingBranchOp
346 //===----------------------------------------------------------------------===//
347 
348 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
349   assert(index <= 1 && "invalid successor index");
350   if (index == 1)
351     return SuccessorOperands(getFirstOperandsMutable());
352   return SuccessorOperands(getSecondOperandsMutable());
353 }
354 
355 //===----------------------------------------------------------------------===//
// TestInternalBranchOp
357 //===----------------------------------------------------------------------===//
358 
359 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
360   assert(index <= 1 && "invalid successor index");
361   if (index == 0)
362     return SuccessorOperands(0, getSuccessOperandsMutable());
363   return SuccessorOperands(1, getErrorOperandsMutable());
364 }
365 
366 //===----------------------------------------------------------------------===//
367 // TestDialectCanonicalizerOp
368 //===----------------------------------------------------------------------===//
369 
370 static LogicalResult
371 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
372                                PatternRewriter &rewriter) {
373   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
374       op, rewriter.getI32IntegerAttr(42));
375   return success();
376 }
377 
378 void TestDialect::getCanonicalizationPatterns(
379     RewritePatternSet &results) const {
380   results.add(&dialectCanonicalizationPattern);
381 }
382 
383 //===----------------------------------------------------------------------===//
384 // TestCallOp
385 //===----------------------------------------------------------------------===//
386 
387 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
388   // Check that the callee attribute was specified.
389   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
390   if (!fnAttr)
391     return emitOpError("requires a 'callee' symbol reference attribute");
392   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
393     return emitOpError() << "'" << fnAttr.getValue()
394                          << "' does not reference a valid function";
395   return success();
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // TestFoldToCallOp
400 //===----------------------------------------------------------------------===//
401 
402 namespace {
403 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
404   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
405 
406   LogicalResult matchAndRewrite(FoldToCallOp op,
407                                 PatternRewriter &rewriter) const override {
408     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
409                                               op.getCalleeAttr(), ValueRange());
410     return success();
411   }
412 };
413 } // namespace
414 
/// Canonicalization for FoldToCallOp: registers FoldToCallOpPattern, which
/// replaces the op with a `func.call` to its callee.
void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldToCallOpPattern>(context);
}
419 
420 //===----------------------------------------------------------------------===//
421 // Test Format* operations
422 //===----------------------------------------------------------------------===//
423 
424 //===----------------------------------------------------------------------===//
425 // Parsing
426 
427 static ParseResult parseCustomOptionalOperand(
428     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
429   if (succeeded(parser.parseOptionalLParen())) {
430     optOperand.emplace();
431     if (parser.parseOperand(*optOperand) || parser.parseRParen())
432       return failure();
433   }
434   return success();
435 }
436 
437 static ParseResult parseCustomDirectiveOperands(
438     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
439     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
440     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
441   if (parser.parseOperand(operand))
442     return failure();
443   if (succeeded(parser.parseOptionalComma())) {
444     optOperand.emplace();
445     if (parser.parseOperand(*optOperand))
446       return failure();
447   }
448   if (parser.parseArrow() || parser.parseLParen() ||
449       parser.parseOperandList(varOperands) || parser.parseRParen())
450     return failure();
451   return success();
452 }
453 static ParseResult
454 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
455                             Type &optOperandType,
456                             SmallVectorImpl<Type> &varOperandTypes) {
457   if (parser.parseColon())
458     return failure();
459 
460   if (parser.parseType(operandType))
461     return failure();
462   if (succeeded(parser.parseOptionalComma())) {
463     if (parser.parseType(optOperandType))
464       return failure();
465   }
466   if (parser.parseArrow() || parser.parseLParen() ||
467       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
468     return failure();
469   return success();
470 }
471 static ParseResult
472 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
473                                  Type optOperandType,
474                                  const SmallVectorImpl<Type> &varOperandTypes) {
475   if (parser.parseKeyword("type_refs_capture"))
476     return failure();
477 
478   Type operandType2, optOperandType2;
479   SmallVector<Type, 1> varOperandTypes2;
480   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
481                                   varOperandTypes2))
482     return failure();
483 
484   if (operandType != operandType2 || optOperandType != optOperandType2 ||
485       varOperandTypes != varOperandTypes2)
486     return failure();
487 
488   return success();
489 }
490 static ParseResult parseCustomDirectiveOperandsAndTypes(
491     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
492     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
493     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
494     Type &operandType, Type &optOperandType,
495     SmallVectorImpl<Type> &varOperandTypes) {
496   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
497       parseCustomDirectiveResults(parser, operandType, optOperandType,
498                                   varOperandTypes))
499     return failure();
500   return success();
501 }
502 static ParseResult parseCustomDirectiveRegions(
503     OpAsmParser &parser, Region &region,
504     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
505   if (parser.parseRegion(region))
506     return failure();
507   if (failed(parser.parseOptionalComma()))
508     return success();
509   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
510   if (parser.parseRegion(*varRegion))
511     return failure();
512   varRegions.emplace_back(std::move(varRegion));
513   return success();
514 }
515 static ParseResult
516 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
517                                SmallVectorImpl<Block *> &varSuccessors) {
518   if (parser.parseSuccessor(successor))
519     return failure();
520   if (failed(parser.parseOptionalComma()))
521     return success();
522   Block *varSuccessor;
523   if (parser.parseSuccessor(varSuccessor))
524     return failure();
525   varSuccessors.append(2, varSuccessor);
526   return success();
527 }
528 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
529                                                   IntegerAttr &attr,
530                                                   IntegerAttr &optAttr) {
531   if (parser.parseAttribute(attr))
532     return failure();
533   if (succeeded(parser.parseOptionalComma())) {
534     if (parser.parseAttribute(optAttr))
535       return failure();
536   }
537   return success();
538 }
539 
540 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
541                                                 NamedAttrList &attrs) {
542   return parser.parseOptionalAttrDict(attrs);
543 }
544 static ParseResult parseCustomDirectiveOptionalOperandRef(
545     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
546   int64_t operandCount = 0;
547   if (parser.parseInteger(operandCount))
548     return failure();
549   bool expectedOptionalOperand = operandCount == 0;
550   return success(expectedOptionalOperand != optOperand.hasValue());
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // Printing
555 
556 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
557                                        Value optOperand) {
558   if (optOperand)
559     printer << "(" << optOperand << ") ";
560 }
561 
562 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
563                                          Value operand, Value optOperand,
564                                          OperandRange varOperands) {
565   printer << operand;
566   if (optOperand)
567     printer << ", " << optOperand;
568   printer << " -> (" << varOperands << ")";
569 }
570 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
571                                         Type operandType, Type optOperandType,
572                                         TypeRange varOperandTypes) {
573   printer << " : " << operandType;
574   if (optOperandType)
575     printer << ", " << optOperandType;
576   printer << " -> (" << varOperandTypes << ")";
577 }
/// Prints the `type_refs_capture` keyword followed by the same type syntax as
/// printCustomDirectiveResults. parseCustomDirectiveWithTypeRefs re-parses
/// these types and compares them with the ones already captured.
static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
                                             Operation *op, Type operandType,
                                             Type optOperandType,
                                             TypeRange varOperandTypes) {
  printer << " type_refs_capture ";
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
/// Prints the operand group, then the type group. This mirrors
/// parseCustomDirectiveOperandsAndTypes.
static void printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
594 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
595                                         Region &region,
596                                         MutableArrayRef<Region> varRegions) {
597   printer.printRegion(region);
598   if (!varRegions.empty()) {
599     printer << ", ";
600     for (Region &region : varRegions)
601       printer.printRegion(region);
602   }
603 }
604 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
605                                            Block *successor,
606                                            SuccessorRange varSuccessors) {
607   printer << successor;
608   if (!varSuccessors.empty())
609     printer << ", " << varSuccessors.front();
610 }
611 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
612                                            Attribute attribute,
613                                            Attribute optAttribute) {
614   printer << attribute;
615   if (optAttribute)
616     printer << ", " << optAttribute;
617 }
618 
619 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
620                                          DictionaryAttr attrs) {
621   printer.printOptionalAttrDict(attrs.getValue());
622 }
623 
624 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
625                                                    Operation *op,
626                                                    Value optOperand) {
627   printer << (optOperand ? "1" : "0");
628 }
629 
630 //===----------------------------------------------------------------------===//
631 // Test IsolatedRegionOp - parse passthrough region arguments.
632 //===----------------------------------------------------------------------===//
633 
634 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
635                                     OperationState &result) {
636   OpAsmParser::UnresolvedOperand argInfo;
637   Type argType = parser.getBuilder().getIndexType();
638 
639   // Parse the input operand.
640   if (parser.parseOperand(argInfo) ||
641       parser.resolveOperand(argInfo, argType, result.operands))
642     return failure();
643 
644   // Parse the body region, and reuse the operand info as the argument info.
645   Region *body = result.addRegion();
646   return parser.parseRegion(*body, argInfo, argType, /*argLocations=*/{},
647                             /*enableNameShadowing=*/true);
648 }
649 
650 void IsolatedRegionOp::print(OpAsmPrinter &p) {
651   p << "test.isolated_region ";
652   p.printOperand(getOperand());
653   p.shadowRegionArgs(getRegion(), getOperand());
654   p << ' ';
655   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // Test SSACFGRegionOp
660 //===----------------------------------------------------------------------===//
661 
/// Every region of this op, regardless of `index`, is an SSA CFG region.
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
  return RegionKind::SSACFG;
}
665 
666 //===----------------------------------------------------------------------===//
667 // Test GraphRegionOp
668 //===----------------------------------------------------------------------===//
669 
670 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) {
671   // Parse the body region, and reuse the operand info as the argument info.
672   Region *body = result.addRegion();
673   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
674 }
675 
676 void GraphRegionOp::print(OpAsmPrinter &p) {
677   p << "test.graph_region ";
678   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
679 }
680 
/// Every region of this op, regardless of `index`, is a graph region.
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
  return RegionKind::Graph;
}
684 
685 //===----------------------------------------------------------------------===//
686 // Test AffineScopeOp
687 //===----------------------------------------------------------------------===//
688 
689 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
690   // Parse the body region, and reuse the operand info as the argument info.
691   Region *body = result.addRegion();
692   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
693 }
694 
695 void AffineScopeOp::print(OpAsmPrinter &p) {
696   p << "test.affine_scope ";
697   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // Test parser.
702 //===----------------------------------------------------------------------===//
703 
704 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
705                                          OperationState &result) {
706   if (parser.parseOptionalColon())
707     return success();
708   uint64_t numResults;
709   if (parser.parseInteger(numResults))
710     return failure();
711 
712   IndexType type = parser.getBuilder().getIndexType();
713   for (unsigned i = 0; i < numResults; ++i)
714     result.addTypes(type);
715   return success();
716 }
717 
718 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
719   if (unsigned numResults = getNumResults())
720     p << " : " << numResults;
721 }
722 
723 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
724                                          OperationState &result) {
725   StringRef keyword;
726   if (parser.parseKeyword(&keyword))
727     return failure();
728   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
729   return success();
730 }
731 
732 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
733 
734 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
736 
737 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
738                                     OperationState &result) {
739   if (parser.parseKeyword("wraps"))
740     return failure();
741 
742   // Parse the wrapped op in a region
743   Region &body = *result.addRegion();
744   body.push_back(new Block);
745   Block &block = body.back();
746   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
747   if (!wrappedOp)
748     return failure();
749 
750   // Create a return terminator in the inner region, pass as operand to the
751   // terminator the returned values from the wrapped operation.
752   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
753   OpBuilder builder(parser.getContext());
754   builder.setInsertionPointToEnd(&block);
755   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
756 
757   // Get the results type for the wrapping op from the terminator operands.
758   Operation &returnOp = body.back().back();
759   result.types.append(returnOp.operand_type_begin(),
760                       returnOp.operand_type_end());
761 
762   // Use the location of the wrapped op for the "test.wrapping_region" op.
763   result.location = wrappedOp->getLoc();
764 
765   return success();
766 }
767 
768 void WrappingRegionOp::print(OpAsmPrinter &p) {
769   p << " wraps ";
770   p.printGenericOp(&getRegion().front().front());
771 }
772 
773 //===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
775 //   parseGenericOperationAfterOpName
776 //   parseCustomOperationName
777 //===----------------------------------------------------------------------===//
778 
779 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
780                                          OperationState &result) {
781 
782   SMLoc loc = parser.getCurrentLocation();
783   Location currLocation = parser.getEncodedSourceLoc(loc);
784 
785   // Parse the operands.
786   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
787   if (parser.parseOperandList(operands))
788     return failure();
789 
790   // Check if we are parsing the pretty-printed version
791   //  test.pretty_printed_region start <inner-op> end : <functional-type>
792   // Else fallback to parsing the "non pretty-printed" version.
793   if (!succeeded(parser.parseOptionalKeyword("start")))
794     return parser.parseGenericOperationAfterOpName(
795         result, llvm::makeArrayRef(operands));
796 
797   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
798   if (failed(parseOpNameInfo))
799     return failure();
800 
801   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
802 
803   FunctionType opFntype;
804   Optional<Location> explicitLoc;
805   if (parser.parseKeyword("end") || parser.parseColon() ||
806       parser.parseType(opFntype) ||
807       parser.parseOptionalLocationSpecifier(explicitLoc))
808     return failure();
809 
810   // If location of the op is explicitly provided, then use it; Else use
811   // the parser's current location.
812   Location opLoc = explicitLoc.getValueOr(currLocation);
813 
814   // Derive the SSA-values for op's operands.
815   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
816                              result.operands))
817     return failure();
818 
819   // Add a region for op.
820   Region &region = *result.addRegion();
821 
822   // Create a basic-block inside op's region.
823   Block &block = region.emplaceBlock();
824 
825   // Create and insert an "inner-op" operation in the block.
826   // Just for testing purposes, we can assume that inner op is a binary op with
827   // result and operand types all same as the test-op's first operand.
828   Type innerOpType = opFntype.getInput(0);
829   Value lhs = block.addArgument(innerOpType, opLoc);
830   Value rhs = block.addArgument(innerOpType, opLoc);
831 
832   OpBuilder builder(parser.getBuilder().getContext());
833   builder.setInsertionPointToStart(&block);
834 
835   Operation *innerOp =
836       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
837 
838   // Insert a return statement in the block returning the inner-op's result.
839   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
840 
841   // Populate the op operation-state with result-type and location.
842   result.addTypes(opFntype.getResults());
843   result.location = innerOp->getLoc();
844 
845   return success();
846 }
847 
848 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
849   p << ' ';
850   p.printOperands(getOperands());
851 
852   Operation &innerOp = getRegion().front().front();
853   // Assuming that region has a single non-terminator inner-op, if the inner-op
854   // meets some criteria (which in this case is a simple one  based on the name
855   // of inner-op), then we can print the entire region in a succinct way.
856   // Here we assume that the prototype of "special.op" can be trivially derived
857   // while parsing it back.
858   if (innerOp.getName().getStringRef().equals("special.op")) {
859     p << " start special.op end";
860   } else {
861     p << " (";
862     p.printRegion(getRegion());
863     p << ")";
864   }
865 
866   p << " : ";
867   p.printFunctionalType(*this);
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // Test PolyForOp - parse list of region arguments.
872 //===----------------------------------------------------------------------===//
873 
874 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
875   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
876   // Parse list of region arguments without a delimiter.
877   if (parser.parseRegionArgumentList(ivsInfo))
878     return failure();
879 
880   // Parse the body region.
881   Region *body = result.addRegion();
882   auto &builder = parser.getBuilder();
883   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
884   return parser.parseRegion(*body, ivsInfo, argTypes);
885 }
886 
887 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
888 
889 void PolyForOp::getAsmBlockArgumentNames(Region &region,
890                                          OpAsmSetValueNameFn setNameFn) {
891   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
892   if (!arrayAttr)
893     return;
894   auto args = getRegion().front().getArguments();
895   auto e = std::min(arrayAttr.size(), args.size());
896   for (unsigned i = 0; i < e; ++i) {
897     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
898       setNameFn(args[i], strAttr.getValue());
899   }
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // Test removing op with inner ops.
904 //===----------------------------------------------------------------------===//
905 
906 namespace {
907 struct TestRemoveOpWithInnerOps
908     : public OpRewritePattern<TestOpWithRegionPattern> {
909   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
910 
911   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
912 
913   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
914                                 PatternRewriter &rewriter) const override {
915     rewriter.eraseOp(op);
916     return success();
917   }
918 };
919 } // namespace
920 
921 void TestOpWithRegionPattern::getCanonicalizationPatterns(
922     RewritePatternSet &results, MLIRContext *context) {
923   results.add<TestRemoveOpWithInnerOps>(context);
924 }
925 
926 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
927   return getOperand();
928 }
929 
930 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
931   return getValue();
932 }
933 
934 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
935     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
936   for (Value input : this->getOperands()) {
937     results.push_back(input);
938   }
939   return success();
940 }
941 
942 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
943   assert(operands.size() == 1);
944   if (operands.front()) {
945     (*this)->setAttr("attr", operands.front());
946     return getResult();
947   }
948   return {};
949 }
950 
951 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
952   return getOperand();
953 }
954 
955 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
956     MLIRContext *, Optional<Location> location, ValueRange operands,
957     DictionaryAttr attributes, RegionRange regions,
958     SmallVectorImpl<Type> &inferredReturnTypes) {
959   if (operands[0].getType() != operands[1].getType()) {
960     return emitOptionalError(location, "operand type mismatch ",
961                              operands[0].getType(), " vs ",
962                              operands[1].getType());
963   }
964   inferredReturnTypes.assign({operands[0].getType()});
965   return success();
966 }
967 
968 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
969     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
970     DictionaryAttr attributes, RegionRange regions,
971     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
972   // Create return type consisting of the last element of the first operand.
973   auto operandType = operands.front().getType();
974   auto sval = operandType.dyn_cast<ShapedType>();
975   if (!sval) {
976     return emitOptionalError(location, "only shaped type operands allowed");
977   }
978   int64_t dim =
979       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
980   auto type = IntegerType::get(context, 17);
981   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
982   return success();
983 }
984 
985 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
986     OpBuilder &builder, ValueRange operands,
987     llvm::SmallVectorImpl<Value> &shapes) {
988   shapes = SmallVector<Value, 1>{
989       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
990   return success();
991 }
992 
993 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
994     OpBuilder &builder, ValueRange operands,
995     llvm::SmallVectorImpl<Value> &shapes) {
996   Location loc = getLoc();
997   shapes.reserve(operands.size());
998   for (Value operand : llvm::reverse(operands)) {
999     auto rank = operand.getType().cast<RankedTensorType>().getRank();
1000     auto currShape = llvm::to_vector<4>(
1001         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1002           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1003         }));
1004     shapes.push_back(builder.create<tensor::FromElementsOp>(
1005         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1006         currShape));
1007   }
1008   return success();
1009 }
1010 
1011 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1012     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1013   Location loc = getLoc();
1014   shapes.reserve(getNumOperands());
1015   for (Value operand : llvm::reverse(getOperands())) {
1016     auto currShape = llvm::to_vector<4>(llvm::map_range(
1017         llvm::seq<int64_t>(
1018             0, operand.getType().cast<RankedTensorType>().getRank()),
1019         [&](int64_t dim) -> Value {
1020           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1021         }));
1022     shapes.emplace_back(std::move(currShape));
1023   }
1024   return success();
1025 }
1026 
1027 //===----------------------------------------------------------------------===//
1028 // Test SideEffect interfaces
1029 //===----------------------------------------------------------------------===//
1030 
1031 namespace {
1032 /// A test resource for side effects.
1033 struct TestResource : public SideEffects::Resource::Base<TestResource> {
1034   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
1035 
1036   StringRef getName() final { return "<Test>"; }
1037 };
1038 } // namespace
1039 
1040 static void testSideEffectOpGetEffect(
1041     Operation *op,
1042     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1043         &effects) {
1044   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1045   if (!effectsAttr)
1046     return;
1047 
1048   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1049 }
1050 
1051 void SideEffectOp::getEffects(
1052     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1053   // Check for an effects attribute on the op instance.
1054   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1055   if (!effectsAttr)
1056     return;
1057 
1058   // If there is one, it is an array of dictionary attributes that hold
1059   // information on the effects of this operation.
1060   for (Attribute element : effectsAttr) {
1061     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1062 
1063     // Get the specific memory effect.
1064     MemoryEffects::Effect *effect =
1065         StringSwitch<MemoryEffects::Effect *>(
1066             effectElement.get("effect").cast<StringAttr>().getValue())
1067             .Case("allocate", MemoryEffects::Allocate::get())
1068             .Case("free", MemoryEffects::Free::get())
1069             .Case("read", MemoryEffects::Read::get())
1070             .Case("write", MemoryEffects::Write::get());
1071 
1072     // Check for a non-default resource to use.
1073     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1074     if (effectElement.get("test_resource"))
1075       resource = TestResource::get();
1076 
1077     // Check for a result to affect.
1078     if (effectElement.get("on_result"))
1079       effects.emplace_back(effect, getResult(), resource);
1080     else if (Attribute ref = effectElement.get("on_reference"))
1081       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1082     else
1083       effects.emplace_back(effect, resource);
1084   }
1085 }
1086 
1087 void SideEffectOp::getEffects(
1088     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1089   testSideEffectOpGetEffect(getOperation(), effects);
1090 }
1091 
1092 //===----------------------------------------------------------------------===//
1093 // StringAttrPrettyNameOp
1094 //===----------------------------------------------------------------------===//
1095 
1096 // This op has fancy handling of its SSA result name.
1097 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1098                                           OperationState &result) {
1099   // Add the result types.
1100   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1101     result.addTypes(parser.getBuilder().getIntegerType(32));
1102 
1103   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1104     return failure();
1105 
1106   // If the attribute dictionary contains no 'names' attribute, infer it from
1107   // the SSA name (if specified).
1108   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1109     return attr.getName() == "names";
1110   });
1111 
1112   // If there was no name specified, check to see if there was a useful name
1113   // specified in the asm file.
1114   if (hadNames || parser.getNumResults() == 0)
1115     return success();
1116 
1117   SmallVector<StringRef, 4> names;
1118   auto *context = result.getContext();
1119 
1120   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1121     auto resultName = parser.getResultName(i);
1122     StringRef nameStr;
1123     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1124       nameStr = resultName.first;
1125 
1126     names.push_back(nameStr);
1127   }
1128 
1129   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1130   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1131   return success();
1132 }
1133 
1134 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1135   // Note that we only need to print the "name" attribute if the asmprinter
1136   // result name disagrees with it.  This can happen in strange cases, e.g.
1137   // when there are conflicts.
1138   bool namesDisagree = getNames().size() != getNumResults();
1139 
1140   SmallString<32> resultNameStr;
1141   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1142     resultNameStr.clear();
1143     llvm::raw_svector_ostream tmpStream(resultNameStr);
1144     p.printOperand(getResult(i), tmpStream);
1145 
1146     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1147     if (!expectedName ||
1148         tmpStream.str().drop_front() != expectedName.getValue()) {
1149       namesDisagree = true;
1150     }
1151   }
1152 
1153   if (namesDisagree)
1154     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1155   else
1156     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1157 }
1158 
1159 // We set the SSA name in the asm syntax to the contents of the name
1160 // attribute.
1161 void StringAttrPrettyNameOp::getAsmResultNames(
1162     function_ref<void(Value, StringRef)> setNameFn) {
1163 
1164   auto value = getNames();
1165   for (size_t i = 0, e = value.size(); i != e; ++i)
1166     if (auto str = value[i].dyn_cast<StringAttr>())
1167       if (!str.getValue().empty())
1168         setNameFn(getResult(i), str.getValue());
1169 }
1170 
1171 //===----------------------------------------------------------------------===//
1172 // ResultTypeWithTraitOp
1173 //===----------------------------------------------------------------------===//
1174 
1175 LogicalResult ResultTypeWithTraitOp::verify() {
1176   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1177     return success();
1178   return emitError("result type should have trait 'TestTypeTrait'");
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // AttrWithTraitOp
1183 //===----------------------------------------------------------------------===//
1184 
1185 LogicalResult AttrWithTraitOp::verify() {
1186   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1187     return success();
1188   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // RegionIfOp
1193 //===----------------------------------------------------------------------===//
1194 
1195 void RegionIfOp::print(OpAsmPrinter &p) {
1196   p << " ";
1197   p.printOperands(getOperands());
1198   p << ": " << getOperandTypes();
1199   p.printArrowTypeList(getResultTypes());
1200   p << " then ";
1201   p.printRegion(getThenRegion(),
1202                 /*printEntryBlockArgs=*/true,
1203                 /*printBlockTerminators=*/true);
1204   p << " else ";
1205   p.printRegion(getElseRegion(),
1206                 /*printEntryBlockArgs=*/true,
1207                 /*printBlockTerminators=*/true);
1208   p << " join ";
1209   p.printRegion(getJoinRegion(),
1210                 /*printEntryBlockArgs=*/true,
1211                 /*printBlockTerminators=*/true);
1212 }
1213 
1214 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1215   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1216   SmallVector<Type, 2> operandTypes;
1217 
1218   result.regions.reserve(3);
1219   Region *thenRegion = result.addRegion();
1220   Region *elseRegion = result.addRegion();
1221   Region *joinRegion = result.addRegion();
1222 
1223   // Parse operand, type and arrow type lists.
1224   if (parser.parseOperandList(operandInfos) ||
1225       parser.parseColonTypeList(operandTypes) ||
1226       parser.parseArrowTypeList(result.types))
1227     return failure();
1228 
1229   // Parse all attached regions.
1230   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1231       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1232       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1233     return failure();
1234 
1235   return parser.resolveOperands(operandInfos, operandTypes,
1236                                 parser.getCurrentLocation(), result.operands);
1237 }
1238 
1239 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1240   assert(index < 2 && "invalid region index");
1241   return getOperands();
1242 }
1243 
1244 void RegionIfOp::getSuccessorRegions(
1245     Optional<unsigned> index, ArrayRef<Attribute> operands,
1246     SmallVectorImpl<RegionSuccessor> &regions) {
1247   // We always branch to the join region.
1248   if (index.hasValue()) {
1249     if (index.getValue() < 2)
1250       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1251     else
1252       regions.push_back(RegionSuccessor(getResults()));
1253     return;
1254   }
1255 
1256   // The then and else regions are the entry regions of this op.
1257   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1258   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1259 }
1260 
1261 void RegionIfOp::getRegionInvocationBounds(
1262     ArrayRef<Attribute> operands,
1263     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1264   // Each region is invoked at most once.
1265   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1266 }
1267 
1268 //===----------------------------------------------------------------------===//
1269 // AnyCondOp
1270 //===----------------------------------------------------------------------===//
1271 
1272 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1273                                     ArrayRef<Attribute> operands,
1274                                     SmallVectorImpl<RegionSuccessor> &regions) {
1275   // The parent op branches into the only region, and the region branches back
1276   // to the parent op.
1277   if (index)
1278     regions.emplace_back(&getRegion());
1279   else
1280     regions.emplace_back(getResults());
1281 }
1282 
1283 void AnyCondOp::getRegionInvocationBounds(
1284     ArrayRef<Attribute> operands,
1285     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1286   invocationBounds.emplace_back(1, 1);
1287 }
1288 
1289 //===----------------------------------------------------------------------===//
1290 // SingleNoTerminatorCustomAsmOp
1291 //===----------------------------------------------------------------------===//
1292 
1293 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1294                                                  OperationState &state) {
1295   Region *body = state.addRegion();
1296   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1297     return failure();
1298   return success();
1299 }
1300 
1301 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1302   printer.printRegion(
1303       getRegion(), /*printEntryBlockArgs=*/false,
1304       // This op has a single block without terminators. But explicitly mark
1305       // as not printing block terminators for testing.
1306       /*printBlockTerminators=*/false);
1307 }
1308 
1309 //===----------------------------------------------------------------------===//
1310 // TestVerifiersOp
1311 //===----------------------------------------------------------------------===//
1312 
1313 LogicalResult TestVerifiersOp::verify() {
1314   if (!getRegion().hasOneBlock())
1315     return emitOpError("`hasOneBlock` trait hasn't been verified");
1316 
1317   Operation *definingOp = getInput().getDefiningOp();
1318   if (definingOp && failed(mlir::verify(definingOp)))
1319     return emitOpError("operand hasn't been verified");
1320 
1321   emitRemark("success run of verifier");
1322 
1323   return success();
1324 }
1325 
1326 LogicalResult TestVerifiersOp::verifyRegions() {
1327   if (!getRegion().hasOneBlock())
1328     return emitOpError("`hasOneBlock` trait hasn't been verified");
1329 
1330   for (Block &block : getRegion())
1331     for (Operation &op : block)
1332       if (failed(mlir::verify(&op)))
1333         return emitOpError("nested op hasn't been verified");
1334 
1335   emitRemark("success run of region verifier");
1336 
1337   return success();
1338 }
1339 
1340 #include "TestOpEnums.cpp.inc"
1341 #include "TestOpInterfaces.cpp.inc"
1342 #include "TestOpStructs.cpp.inc"
1343 #include "TestTypeInterfaces.cpp.inc"
1344 
1345 #define GET_OP_CLASSES
1346 #include "TestOps.cpp.inc"
1347