1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/DialectImplementation.h"
19 #include "mlir/IR/ExtensibleDialect.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
22 #include "mlir/IR/Verifier.h"
23 #include "mlir/Reducer/ReductionPatternInterface.h"
24 #include "mlir/Transforms/FoldUtils.h"
25 #include "mlir/Transforms/InliningUtils.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSwitch.h"
28 
29 // Include this before the using namespace lines below to
30 // test that we don't have namespace dependencies.
31 #include "TestOpsDialect.cpp.inc"
32 
33 using namespace mlir;
34 using namespace test;
35 
36 void test::registerTestDialect(DialectRegistry &registry) {
37   registry.insert<TestDialect>();
38 }
39 
40 //===----------------------------------------------------------------------===//
41 // TestDialect Interfaces
42 //===----------------------------------------------------------------------===//
43 
44 namespace {
45 
/// Testing the correctness of some traits.
/// Both checks run at compile time: the first requires that
/// `OpTrait::has_implicit_terminator_t` is detectable on
/// SingleBlockImplicitTerminatorOp, the second that
/// `OpTrait::hasSingleBlockImplicitTerminator` evaluates to true for it.
/// A mismatch fails the build with the message given below.
static_assert(
    llvm::is_detected<OpTrait::has_implicit_terminator_t,
                      SingleBlockImplicitTerminatorOp>::value,
    "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
                  SingleBlockImplicitTerminatorOp>::value,
              "hasSingleBlockImplicitTerminator does not match "
              "SingleBlockImplicitTerminatorOp");
55 
56 // Test support for interacting with the AsmPrinter.
57 struct TestOpAsmInterface : public OpAsmDialectInterface {
58   using OpAsmDialectInterface::OpAsmDialectInterface;
59 
60   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
61     StringAttr strAttr = attr.dyn_cast<StringAttr>();
62     if (!strAttr)
63       return AliasResult::NoAlias;
64 
65     // Check the contents of the string attribute to see what the test alias
66     // should be named.
67     Optional<StringRef> aliasName =
68         StringSwitch<Optional<StringRef>>(strAttr.getValue())
69             .Case("alias_test:dot_in_name", StringRef("test.alias"))
70             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
71             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
72             .Case("alias_test:sanitize_conflict_a",
73                   StringRef("test_alias_conflict0"))
74             .Case("alias_test:sanitize_conflict_b",
75                   StringRef("test_alias_conflict0_"))
76             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
77             .Default(llvm::None);
78     if (!aliasName)
79       return AliasResult::NoAlias;
80 
81     os << *aliasName;
82     return AliasResult::FinalAlias;
83   }
84 
85   AliasResult getAlias(Type type, raw_ostream &os) const final {
86     if (auto tupleType = type.dyn_cast<TupleType>()) {
87       if (tupleType.size() > 0 &&
88           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
89             return elemType.isa<SimpleAType>();
90           })) {
91         os << "test_tuple";
92         return AliasResult::FinalAlias;
93       }
94     }
95     if (auto intType = type.dyn_cast<TestIntegerType>()) {
96       if (intType.getSignedness() ==
97               TestIntegerType::SignednessSemantics::Unsigned &&
98           intType.getWidth() == 8) {
99         os << "test_ui8";
100         return AliasResult::FinalAlias;
101       }
102     }
103     return AliasResult::NoAlias;
104   }
105 };
106 
107 struct TestDialectFoldInterface : public DialectFoldInterface {
108   using DialectFoldInterface::DialectFoldInterface;
109 
110   /// Registered hook to check if the given region, which is attached to an
111   /// operation that is *not* isolated from above, should be used when
112   /// materializing constants.
113   bool shouldMaterializeInto(Region *region) const final {
114     // If this is a one region operation, then insert into it.
115     return isa<OneRegionOp>(region->getParentOp());
116   }
117 };
118 
119 /// This class defines the interface for handling inlining with standard
120 /// operations.
121 struct TestInlinerInterface : public DialectInlinerInterface {
122   using DialectInlinerInterface::DialectInlinerInterface;
123 
124   //===--------------------------------------------------------------------===//
125   // Analysis Hooks
126   //===--------------------------------------------------------------------===//
127 
128   bool isLegalToInline(Operation *call, Operation *callable,
129                        bool wouldBeCloned) const final {
130     // Don't allow inlining calls that are marked `noinline`.
131     return !call->hasAttr("noinline");
132   }
133   bool isLegalToInline(Region *, Region *, bool,
134                        BlockAndValueMapping &) const final {
135     // Inlining into test dialect regions is legal.
136     return true;
137   }
138   bool isLegalToInline(Operation *, Region *, bool,
139                        BlockAndValueMapping &) const final {
140     return true;
141   }
142 
143   bool shouldAnalyzeRecursively(Operation *op) const final {
144     // Analyze recursively if this is not a functional region operation, it
145     // froms a separate functional scope.
146     return !isa<FunctionalRegionOp>(op);
147   }
148 
149   //===--------------------------------------------------------------------===//
150   // Transformation Hooks
151   //===--------------------------------------------------------------------===//
152 
153   /// Handle the given inlined terminator by replacing it with a new operation
154   /// as necessary.
155   void handleTerminator(Operation *op,
156                         ArrayRef<Value> valuesToRepl) const final {
157     // Only handle "test.return" here.
158     auto returnOp = dyn_cast<TestReturnOp>(op);
159     if (!returnOp)
160       return;
161 
162     // Replace the values directly with the return operands.
163     assert(returnOp.getNumOperands() == valuesToRepl.size());
164     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
165       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
166   }
167 
168   /// Attempt to materialize a conversion for a type mismatch between a call
169   /// from this dialect, and a callable region. This method should generate an
170   /// operation that takes 'input' as the only operand, and produces a single
171   /// result of 'resultType'. If a conversion can not be generated, nullptr
172   /// should be returned.
173   Operation *materializeCallConversion(OpBuilder &builder, Value input,
174                                        Type resultType,
175                                        Location conversionLoc) const final {
176     // Only allow conversion for i16/i32 types.
177     if (!(resultType.isSignlessInteger(16) ||
178           resultType.isSignlessInteger(32)) ||
179         !(input.getType().isSignlessInteger(16) ||
180           input.getType().isSignlessInteger(32)))
181       return nullptr;
182     return builder.create<TestCastOp>(conversionLoc, resultType, input);
183   }
184 
185   void processInlinedCallBlocks(
186       Operation *call,
187       iterator_range<Region::iterator> inlinedBlocks) const final {
188     if (!isa<ConversionCallOp>(call))
189       return;
190 
191     // Set attributed on all ops in the inlined blocks.
192     for (Block &block : inlinedBlocks) {
193       block.walk([&](Operation *op) {
194         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
195       });
196     }
197   }
198 };
199 
200 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
201 public:
202   TestReductionPatternInterface(Dialect *dialect)
203       : DialectReductionPatternInterface(dialect) {}
204 
205   void populateReductionPatterns(RewritePatternSet &patterns) const final {
206     populateTestReductionPatterns(patterns);
207   }
208 };
209 
210 } // namespace
211 
212 //===----------------------------------------------------------------------===//
213 // Dynamic operations
214 //===----------------------------------------------------------------------===//
215 
216 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
217   return DynamicOpDefinition::get(
218       "dynamic_generic", dialect, [](Operation *op) { return success(); },
219       [](Operation *op) { return success(); });
220 }
221 
222 std::unique_ptr<DynamicOpDefinition>
223 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
224   return DynamicOpDefinition::get(
225       "dynamic_one_operand_two_results", dialect,
226       [](Operation *op) {
227         if (op->getNumOperands() != 1) {
228           op->emitOpError()
229               << "expected 1 operand, but had " << op->getNumOperands();
230           return failure();
231         }
232         if (op->getNumResults() != 2) {
233           op->emitOpError()
234               << "expected 2 results, but had " << op->getNumResults();
235           return failure();
236         }
237         return success();
238       },
239       [](Operation *op) { return success(); });
240 }
241 
242 std::unique_ptr<DynamicOpDefinition>
243 getDynamicCustomParserPrinterOp(TestDialect *dialect) {
244   auto verifier = [](Operation *op) {
245     if (op->getNumOperands() == 0 && op->getNumResults() == 0)
246       return success();
247     op->emitError() << "operation should have no operands and no results";
248     return failure();
249   };
250   auto regionVerifier = [](Operation *op) { return success(); };
251 
252   auto parser = [](OpAsmParser &parser, OperationState &state) {
253     return parser.parseKeyword("custom_keyword");
254   };
255 
256   auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
257     printer << op->getName() << " custom_keyword";
258   };
259 
260   return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
261                                   verifier, regionVerifier, parser, printer);
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // TestDialect
266 //===----------------------------------------------------------------------===//
267 
// Forward declaration so TestOpEffectInterfaceFallback::getEffects (below) can
// call it; the definition is elsewhere in this file. Presumably it appends the
// TestEffects effects of `op` to `effects` — NOTE(review): confirm at the
// definition site.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
271 
272 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
273 struct TestOpEffectInterfaceFallback
274     : public TestEffectOpInterface::FallbackModel<
275           TestOpEffectInterfaceFallback> {
276   static bool classof(Operation *op) {
277     bool isSupportedOp =
278         op->getName().getStringRef() == "test.unregistered_side_effect_op";
279     assert(isSupportedOp && "Unexpected dispatch");
280     return isSupportedOp;
281   }
282 
283   void
284   getEffects(Operation *op,
285              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
286                  &effects) const {
287     testSideEffectOpGetEffect(op, effects);
288   }
289 };
290 
/// Registers everything the test dialect provides: attributes, types, the
/// ODS-generated operations, three dynamically defined operations, the dialect
/// interfaces declared in this file, and a heap-allocated fallback for
/// `TestEffectOpInterface`. Unknown operations are allowed in this dialect.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The operation list is generated by ODS into TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  // Operations defined at runtime by the getDynamic*Op factories.
  registerDynamicOp(getDynamicGenericOp(this));
  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
  registerDynamicOp(getDynamicCustomParserPrinterOp(this));

  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. The object is deleted in ~TestDialect().
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
310 TestDialect::~TestDialect() {
311   delete static_cast<TestOpEffectInterfaceFallback *>(
312       fallbackEffectOpInterfaces);
313 }
314 
315 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
316                                             Type type, Location loc) {
317   return builder.create<TestOpConstant>(loc, type, value);
318 }
319 
320 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
321     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
322     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
323     ::mlir::RegionRange regions,
324     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
325   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
326   return ::mlir::success();
327 }
328 
329 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
330                                                OperationName opName) {
331   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
332       typeID == TypeID::get<TestEffectOpInterface>())
333     return fallbackEffectOpInterfaces;
334   return nullptr;
335 }
336 
337 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
338                                                     NamedAttribute namedAttr) {
339   if (namedAttr.getName() == "test.invalid_attr")
340     return op->emitError() << "invalid to use 'test.invalid_attr'";
341   return success();
342 }
343 
344 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
345                                                     unsigned regionIndex,
346                                                     unsigned argIndex,
347                                                     NamedAttribute namedAttr) {
348   if (namedAttr.getName() == "test.invalid_attr")
349     return op->emitError() << "invalid to use 'test.invalid_attr'";
350   return success();
351 }
352 
353 LogicalResult
354 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
355                                          unsigned resultIndex,
356                                          NamedAttribute namedAttr) {
357   if (namedAttr.getName() == "test.invalid_attr")
358     return op->emitError() << "invalid to use 'test.invalid_attr'";
359   return success();
360 }
361 
362 Optional<Dialect::ParseOpHook>
363 TestDialect::getParseOperationHook(StringRef opName) const {
364   if (opName == "test.dialect_custom_printer") {
365     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
366       return parser.parseKeyword("custom_format");
367     }};
368   }
369   if (opName == "test.dialect_custom_format_fallback") {
370     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
371       return parser.parseKeyword("custom_format_fallback");
372     }};
373   }
374   return None;
375 }
376 
377 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
378 TestDialect::getOperationPrinter(Operation *op) const {
379   StringRef opName = op->getName().getStringRef();
380   if (opName == "test.dialect_custom_printer") {
381     return [](Operation *op, OpAsmPrinter &printer) {
382       printer.getStream() << " custom_format";
383     };
384   }
385   if (opName == "test.dialect_custom_format_fallback") {
386     return [](Operation *op, OpAsmPrinter &printer) {
387       printer.getStream() << " custom_format_fallback";
388     };
389   }
390   return {};
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // TestBranchOp
395 //===----------------------------------------------------------------------===//
396 
397 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
398   assert(index == 0 && "invalid successor index");
399   return SuccessorOperands(getTargetOperandsMutable());
400 }
401 
402 //===----------------------------------------------------------------------===//
403 // TestProducingBranchOp
404 //===----------------------------------------------------------------------===//
405 
406 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
407   assert(index <= 1 && "invalid successor index");
408   if (index == 1)
409     return SuccessorOperands(getFirstOperandsMutable());
410   return SuccessorOperands(getSecondOperandsMutable());
411 }
412 
413 //===----------------------------------------------------------------------===//
// TestInternalBranchOp
415 //===----------------------------------------------------------------------===//
416 
417 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
418   assert(index <= 1 && "invalid successor index");
419   if (index == 0)
420     return SuccessorOperands(0, getSuccessOperandsMutable());
421   return SuccessorOperands(1, getErrorOperandsMutable());
422 }
423 
424 //===----------------------------------------------------------------------===//
425 // TestDialectCanonicalizerOp
426 //===----------------------------------------------------------------------===//
427 
428 static LogicalResult
429 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
430                                PatternRewriter &rewriter) {
431   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
432       op, rewriter.getI32IntegerAttr(42));
433   return success();
434 }
435 
436 void TestDialect::getCanonicalizationPatterns(
437     RewritePatternSet &results) const {
438   results.add(&dialectCanonicalizationPattern);
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // TestCallOp
443 //===----------------------------------------------------------------------===//
444 
445 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
446   // Check that the callee attribute was specified.
447   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
448   if (!fnAttr)
449     return emitOpError("requires a 'callee' symbol reference attribute");
450   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
451     return emitOpError() << "'" << fnAttr.getValue()
452                          << "' does not reference a valid function";
453   return success();
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // TestFoldToCallOp
458 //===----------------------------------------------------------------------===//
459 
460 namespace {
461 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
462   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
463 
464   LogicalResult matchAndRewrite(FoldToCallOp op,
465                                 PatternRewriter &rewriter) const override {
466     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
467                                               op.getCalleeAttr(), ValueRange());
468     return success();
469   }
470 };
471 } // namespace
472 
473 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
474                                                MLIRContext *context) {
475   results.add<FoldToCallOpPattern>(context);
476 }
477 
478 //===----------------------------------------------------------------------===//
479 // Test Format* operations
480 //===----------------------------------------------------------------------===//
481 
482 //===----------------------------------------------------------------------===//
483 // Parsing
484 
485 static ParseResult parseCustomOptionalOperand(
486     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
487   if (succeeded(parser.parseOptionalLParen())) {
488     optOperand.emplace();
489     if (parser.parseOperand(*optOperand) || parser.parseRParen())
490       return failure();
491   }
492   return success();
493 }
494 
495 static ParseResult parseCustomDirectiveOperands(
496     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
497     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
498     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
499   if (parser.parseOperand(operand))
500     return failure();
501   if (succeeded(parser.parseOptionalComma())) {
502     optOperand.emplace();
503     if (parser.parseOperand(*optOperand))
504       return failure();
505   }
506   if (parser.parseArrow() || parser.parseLParen() ||
507       parser.parseOperandList(varOperands) || parser.parseRParen())
508     return failure();
509   return success();
510 }
511 static ParseResult
512 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
513                             Type &optOperandType,
514                             SmallVectorImpl<Type> &varOperandTypes) {
515   if (parser.parseColon())
516     return failure();
517 
518   if (parser.parseType(operandType))
519     return failure();
520   if (succeeded(parser.parseOptionalComma())) {
521     if (parser.parseType(optOperandType))
522       return failure();
523   }
524   if (parser.parseArrow() || parser.parseLParen() ||
525       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
526     return failure();
527   return success();
528 }
529 static ParseResult
530 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
531                                  Type optOperandType,
532                                  const SmallVectorImpl<Type> &varOperandTypes) {
533   if (parser.parseKeyword("type_refs_capture"))
534     return failure();
535 
536   Type operandType2, optOperandType2;
537   SmallVector<Type, 1> varOperandTypes2;
538   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
539                                   varOperandTypes2))
540     return failure();
541 
542   if (operandType != operandType2 || optOperandType != optOperandType2 ||
543       varOperandTypes != varOperandTypes2)
544     return failure();
545 
546   return success();
547 }
548 static ParseResult parseCustomDirectiveOperandsAndTypes(
549     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
550     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
551     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
552     Type &operandType, Type &optOperandType,
553     SmallVectorImpl<Type> &varOperandTypes) {
554   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
555       parseCustomDirectiveResults(parser, operandType, optOperandType,
556                                   varOperandTypes))
557     return failure();
558   return success();
559 }
560 static ParseResult parseCustomDirectiveRegions(
561     OpAsmParser &parser, Region &region,
562     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
563   if (parser.parseRegion(region))
564     return failure();
565   if (failed(parser.parseOptionalComma()))
566     return success();
567   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
568   if (parser.parseRegion(*varRegion))
569     return failure();
570   varRegions.emplace_back(std::move(varRegion));
571   return success();
572 }
573 static ParseResult
574 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
575                                SmallVectorImpl<Block *> &varSuccessors) {
576   if (parser.parseSuccessor(successor))
577     return failure();
578   if (failed(parser.parseOptionalComma()))
579     return success();
580   Block *varSuccessor;
581   if (parser.parseSuccessor(varSuccessor))
582     return failure();
583   varSuccessors.append(2, varSuccessor);
584   return success();
585 }
586 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
587                                                   IntegerAttr &attr,
588                                                   IntegerAttr &optAttr) {
589   if (parser.parseAttribute(attr))
590     return failure();
591   if (succeeded(parser.parseOptionalComma())) {
592     if (parser.parseAttribute(optAttr))
593       return failure();
594   }
595   return success();
596 }
597 
598 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
599                                                 NamedAttrList &attrs) {
600   return parser.parseOptionalAttrDict(attrs);
601 }
602 static ParseResult parseCustomDirectiveOptionalOperandRef(
603     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
604   int64_t operandCount = 0;
605   if (parser.parseInteger(operandCount))
606     return failure();
607   bool expectedOptionalOperand = operandCount == 0;
608   return success(expectedOptionalOperand != optOperand.hasValue());
609 }
610 
611 //===----------------------------------------------------------------------===//
612 // Printing
613 
614 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
615                                        Value optOperand) {
616   if (optOperand)
617     printer << "(" << optOperand << ") ";
618 }
619 
620 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
621                                          Value operand, Value optOperand,
622                                          OperandRange varOperands) {
623   printer << operand;
624   if (optOperand)
625     printer << ", " << optOperand;
626   printer << " -> (" << varOperands << ")";
627 }
628 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
629                                         Type operandType, Type optOperandType,
630                                         TypeRange varOperandTypes) {
631   printer << " : " << operandType;
632   if (optOperandType)
633     printer << ", " << optOperandType;
634   printer << " -> (" << varOperandTypes << ")";
635 }
636 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
637                                              Operation *op, Type operandType,
638                                              Type optOperandType,
639                                              TypeRange varOperandTypes) {
640   printer << " type_refs_capture ";
641   printCustomDirectiveResults(printer, op, operandType, optOperandType,
642                               varOperandTypes);
643 }
644 static void printCustomDirectiveOperandsAndTypes(
645     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
646     OperandRange varOperands, Type operandType, Type optOperandType,
647     TypeRange varOperandTypes) {
648   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
649   printCustomDirectiveResults(printer, op, operandType, optOperandType,
650                               varOperandTypes);
651 }
652 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
653                                         Region &region,
654                                         MutableArrayRef<Region> varRegions) {
655   printer.printRegion(region);
656   if (!varRegions.empty()) {
657     printer << ", ";
658     for (Region &region : varRegions)
659       printer.printRegion(region);
660   }
661 }
662 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
663                                            Block *successor,
664                                            SuccessorRange varSuccessors) {
665   printer << successor;
666   if (!varSuccessors.empty())
667     printer << ", " << varSuccessors.front();
668 }
669 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
670                                            Attribute attribute,
671                                            Attribute optAttribute) {
672   printer << attribute;
673   if (optAttribute)
674     printer << ", " << optAttribute;
675 }
676 
677 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
678                                          DictionaryAttr attrs) {
679   printer.printOptionalAttrDict(attrs.getValue());
680 }
681 
682 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
683                                                    Operation *op,
684                                                    Value optOperand) {
685   printer << (optOperand ? "1" : "0");
686 }
687 
688 //===----------------------------------------------------------------------===//
689 // Test IsolatedRegionOp - parse passthrough region arguments.
690 //===----------------------------------------------------------------------===//
691 
692 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
693                                     OperationState &result) {
694   OpAsmParser::UnresolvedOperand argInfo;
695   Type argType = parser.getBuilder().getIndexType();
696 
697   // Parse the input operand.
698   if (parser.parseOperand(argInfo) ||
699       parser.resolveOperand(argInfo, argType, result.operands))
700     return failure();
701 
702   // Parse the body region, and reuse the operand info as the argument info.
703   Region *body = result.addRegion();
704   return parser.parseRegion(*body, argInfo, argType,
705                             /*enableNameShadowing=*/true);
706 }
707 
708 void IsolatedRegionOp::print(OpAsmPrinter &p) {
709   p << "test.isolated_region ";
710   p.printOperand(getOperand());
711   p.shadowRegionArgs(getRegion(), getOperand());
712   p << ' ';
713   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
714 }
715 
716 //===----------------------------------------------------------------------===//
717 // Test SSACFGRegionOp
718 //===----------------------------------------------------------------------===//
719 
720 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
721   return RegionKind::SSACFG;
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // Test GraphRegionOp
726 //===----------------------------------------------------------------------===//
727 
728 ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) {
729   // Parse the body region, and reuse the operand info as the argument info.
730   Region *body = result.addRegion();
731   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
732 }
733 
734 void GraphRegionOp::print(OpAsmPrinter &p) {
735   p << "test.graph_region ";
736   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
737 }
738 
739 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
740   return RegionKind::Graph;
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // Test AffineScopeOp
745 //===----------------------------------------------------------------------===//
746 
747 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
748   // Parse the body region, and reuse the operand info as the argument info.
749   Region *body = result.addRegion();
750   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
751 }
752 
753 void AffineScopeOp::print(OpAsmPrinter &p) {
754   p << "test.affine_scope ";
755   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
756 }
757 
758 //===----------------------------------------------------------------------===//
759 // Test parser.
760 //===----------------------------------------------------------------------===//
761 
762 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
763                                          OperationState &result) {
764   if (parser.parseOptionalColon())
765     return success();
766   uint64_t numResults;
767   if (parser.parseInteger(numResults))
768     return failure();
769 
770   IndexType type = parser.getBuilder().getIndexType();
771   for (unsigned i = 0; i < numResults; ++i)
772     result.addTypes(type);
773   return success();
774 }
775 
776 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
777   if (unsigned numResults = getNumResults())
778     p << " : " << numResults;
779 }
780 
781 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
782                                          OperationState &result) {
783   StringRef keyword;
784   if (parser.parseKeyword(&keyword))
785     return failure();
786   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
787   return success();
788 }
789 
790 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
791 
792 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
794 
795 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
796                                     OperationState &result) {
797   if (parser.parseKeyword("wraps"))
798     return failure();
799 
800   // Parse the wrapped op in a region
801   Region &body = *result.addRegion();
802   body.push_back(new Block);
803   Block &block = body.back();
804   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
805   if (!wrappedOp)
806     return failure();
807 
808   // Create a return terminator in the inner region, pass as operand to the
809   // terminator the returned values from the wrapped operation.
810   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
811   OpBuilder builder(parser.getContext());
812   builder.setInsertionPointToEnd(&block);
813   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
814 
815   // Get the results type for the wrapping op from the terminator operands.
816   Operation &returnOp = body.back().back();
817   result.types.append(returnOp.operand_type_begin(),
818                       returnOp.operand_type_end());
819 
820   // Use the location of the wrapped op for the "test.wrapping_region" op.
821   result.location = wrappedOp->getLoc();
822 
823   return success();
824 }
825 
826 void WrappingRegionOp::print(OpAsmPrinter &p) {
827   p << " wraps ";
828   p.printGenericOp(&getRegion().front().front());
829 }
830 
831 //===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
833 //   parseGenericOperationAfterOpName
834 //   parseCustomOperationName
835 //===----------------------------------------------------------------------===//
836 
837 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
838                                          OperationState &result) {
839 
840   SMLoc loc = parser.getCurrentLocation();
841   Location currLocation = parser.getEncodedSourceLoc(loc);
842 
843   // Parse the operands.
844   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
845   if (parser.parseOperandList(operands))
846     return failure();
847 
848   // Check if we are parsing the pretty-printed version
849   //  test.pretty_printed_region start <inner-op> end : <functional-type>
850   // Else fallback to parsing the "non pretty-printed" version.
851   if (!succeeded(parser.parseOptionalKeyword("start")))
852     return parser.parseGenericOperationAfterOpName(
853         result, llvm::makeArrayRef(operands));
854 
855   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
856   if (failed(parseOpNameInfo))
857     return failure();
858 
859   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
860 
861   FunctionType opFntype;
862   Optional<Location> explicitLoc;
863   if (parser.parseKeyword("end") || parser.parseColon() ||
864       parser.parseType(opFntype) ||
865       parser.parseOptionalLocationSpecifier(explicitLoc))
866     return failure();
867 
868   // If location of the op is explicitly provided, then use it; Else use
869   // the parser's current location.
870   Location opLoc = explicitLoc.getValueOr(currLocation);
871 
872   // Derive the SSA-values for op's operands.
873   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
874                              result.operands))
875     return failure();
876 
877   // Add a region for op.
878   Region &region = *result.addRegion();
879 
880   // Create a basic-block inside op's region.
881   Block &block = region.emplaceBlock();
882 
883   // Create and insert an "inner-op" operation in the block.
884   // Just for testing purposes, we can assume that inner op is a binary op with
885   // result and operand types all same as the test-op's first operand.
886   Type innerOpType = opFntype.getInput(0);
887   Value lhs = block.addArgument(innerOpType, opLoc);
888   Value rhs = block.addArgument(innerOpType, opLoc);
889 
890   OpBuilder builder(parser.getBuilder().getContext());
891   builder.setInsertionPointToStart(&block);
892 
893   Operation *innerOp =
894       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
895 
896   // Insert a return statement in the block returning the inner-op's result.
897   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
898 
899   // Populate the op operation-state with result-type and location.
900   result.addTypes(opFntype.getResults());
901   result.location = innerOp->getLoc();
902 
903   return success();
904 }
905 
906 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
907   p << ' ';
908   p.printOperands(getOperands());
909 
910   Operation &innerOp = getRegion().front().front();
911   // Assuming that region has a single non-terminator inner-op, if the inner-op
912   // meets some criteria (which in this case is a simple one  based on the name
913   // of inner-op), then we can print the entire region in a succinct way.
914   // Here we assume that the prototype of "special.op" can be trivially derived
915   // while parsing it back.
916   if (innerOp.getName().getStringRef().equals("special.op")) {
917     p << " start special.op end";
918   } else {
919     p << " (";
920     p.printRegion(getRegion());
921     p << ")";
922   }
923 
924   p << " : ";
925   p.printFunctionalType(*this);
926 }
927 
928 //===----------------------------------------------------------------------===//
929 // Test PolyForOp - parse list of region arguments.
930 //===----------------------------------------------------------------------===//
931 
932 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
933   SmallVector<OpAsmParser::UnresolvedOperand, 4> ivsInfo;
934   // Parse list of region arguments without a delimiter.
935   if (parser.parseOperandList(ivsInfo, OpAsmParser::Delimiter::None,
936                               /*allowResultNumber=*/false))
937     return failure();
938 
939   // Parse the body region.
940   Region *body = result.addRegion();
941   auto &builder = parser.getBuilder();
942   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
943   return parser.parseRegion(*body, ivsInfo, argTypes);
944 }
945 
946 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
947 
948 void PolyForOp::getAsmBlockArgumentNames(Region &region,
949                                          OpAsmSetValueNameFn setNameFn) {
950   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
951   if (!arrayAttr)
952     return;
953   auto args = getRegion().front().getArguments();
954   auto e = std::min(arrayAttr.size(), args.size());
955   for (unsigned i = 0; i < e; ++i) {
956     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
957       setNameFn(args[i], strAttr.getValue());
958   }
959 }
960 
961 //===----------------------------------------------------------------------===//
962 // Test removing op with inner ops.
963 //===----------------------------------------------------------------------===//
964 
965 namespace {
966 struct TestRemoveOpWithInnerOps
967     : public OpRewritePattern<TestOpWithRegionPattern> {
968   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
969 
970   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
971 
972   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
973                                 PatternRewriter &rewriter) const override {
974     rewriter.eraseOp(op);
975     return success();
976   }
977 };
978 } // namespace
979 
980 void TestOpWithRegionPattern::getCanonicalizationPatterns(
981     RewritePatternSet &results, MLIRContext *context) {
982   results.add<TestRemoveOpWithInnerOps>(context);
983 }
984 
985 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
986   return getOperand();
987 }
988 
989 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
990   return getValue();
991 }
992 
993 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
994     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
995   for (Value input : this->getOperands()) {
996     results.push_back(input);
997   }
998   return success();
999 }
1000 
1001 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
1002   assert(operands.size() == 1);
1003   if (operands.front()) {
1004     (*this)->setAttr("attr", operands.front());
1005     return getResult();
1006   }
1007   return {};
1008 }
1009 
1010 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
1011   return getOperand();
1012 }
1013 
1014 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
1015     MLIRContext *, Optional<Location> location, ValueRange operands,
1016     DictionaryAttr attributes, RegionRange regions,
1017     SmallVectorImpl<Type> &inferredReturnTypes) {
1018   if (operands[0].getType() != operands[1].getType()) {
1019     return emitOptionalError(location, "operand type mismatch ",
1020                              operands[0].getType(), " vs ",
1021                              operands[1].getType());
1022   }
1023   inferredReturnTypes.assign({operands[0].getType()});
1024   return success();
1025 }
1026 
1027 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
1028     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
1029     DictionaryAttr attributes, RegionRange regions,
1030     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1031   // Create return type consisting of the last element of the first operand.
1032   auto operandType = operands.front().getType();
1033   auto sval = operandType.dyn_cast<ShapedType>();
1034   if (!sval) {
1035     return emitOptionalError(location, "only shaped type operands allowed");
1036   }
1037   int64_t dim =
1038       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
1039   auto type = IntegerType::get(context, 17);
1040   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
1041   return success();
1042 }
1043 
1044 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
1045     OpBuilder &builder, ValueRange operands,
1046     llvm::SmallVectorImpl<Value> &shapes) {
1047   shapes = SmallVector<Value, 1>{
1048       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1049   return success();
1050 }
1051 
1052 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
1053     OpBuilder &builder, ValueRange operands,
1054     llvm::SmallVectorImpl<Value> &shapes) {
1055   Location loc = getLoc();
1056   shapes.reserve(operands.size());
1057   for (Value operand : llvm::reverse(operands)) {
1058     auto rank = operand.getType().cast<RankedTensorType>().getRank();
1059     auto currShape = llvm::to_vector<4>(
1060         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1061           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1062         }));
1063     shapes.push_back(builder.create<tensor::FromElementsOp>(
1064         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1065         currShape));
1066   }
1067   return success();
1068 }
1069 
1070 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1071     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1072   Location loc = getLoc();
1073   shapes.reserve(getNumOperands());
1074   for (Value operand : llvm::reverse(getOperands())) {
1075     auto currShape = llvm::to_vector<4>(llvm::map_range(
1076         llvm::seq<int64_t>(
1077             0, operand.getType().cast<RankedTensorType>().getRank()),
1078         [&](int64_t dim) -> Value {
1079           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1080         }));
1081     shapes.emplace_back(std::move(currShape));
1082   }
1083   return success();
1084 }
1085 
1086 //===----------------------------------------------------------------------===//
1087 // Test SideEffect interfaces
1088 //===----------------------------------------------------------------------===//
1089 
namespace {
/// A test resource for side effects.
///
/// SideEffectOp::getEffects attaches memory effects to this resource, instead
/// of the default resource, when an effect entry has a "test_resource" key.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)

  /// Returns the fixed resource name "<Test>".
  StringRef getName() final { return "<Test>"; }
};
} // namespace
1098 
1099 static void testSideEffectOpGetEffect(
1100     Operation *op,
1101     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1102         &effects) {
1103   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1104   if (!effectsAttr)
1105     return;
1106 
1107   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1108 }
1109 
1110 void SideEffectOp::getEffects(
1111     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1112   // Check for an effects attribute on the op instance.
1113   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1114   if (!effectsAttr)
1115     return;
1116 
1117   // If there is one, it is an array of dictionary attributes that hold
1118   // information on the effects of this operation.
1119   for (Attribute element : effectsAttr) {
1120     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1121 
1122     // Get the specific memory effect.
1123     MemoryEffects::Effect *effect =
1124         StringSwitch<MemoryEffects::Effect *>(
1125             effectElement.get("effect").cast<StringAttr>().getValue())
1126             .Case("allocate", MemoryEffects::Allocate::get())
1127             .Case("free", MemoryEffects::Free::get())
1128             .Case("read", MemoryEffects::Read::get())
1129             .Case("write", MemoryEffects::Write::get());
1130 
1131     // Check for a non-default resource to use.
1132     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1133     if (effectElement.get("test_resource"))
1134       resource = TestResource::get();
1135 
1136     // Check for a result to affect.
1137     if (effectElement.get("on_result"))
1138       effects.emplace_back(effect, getResult(), resource);
1139     else if (Attribute ref = effectElement.get("on_reference"))
1140       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1141     else
1142       effects.emplace_back(effect, resource);
1143   }
1144 }
1145 
1146 void SideEffectOp::getEffects(
1147     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1148   testSideEffectOpGetEffect(getOperation(), effects);
1149 }
1150 
1151 //===----------------------------------------------------------------------===//
1152 // StringAttrPrettyNameOp
1153 //===----------------------------------------------------------------------===//
1154 
1155 // This op has fancy handling of its SSA result name.
1156 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1157                                           OperationState &result) {
1158   // Add the result types.
1159   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1160     result.addTypes(parser.getBuilder().getIntegerType(32));
1161 
1162   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1163     return failure();
1164 
1165   // If the attribute dictionary contains no 'names' attribute, infer it from
1166   // the SSA name (if specified).
1167   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1168     return attr.getName() == "names";
1169   });
1170 
1171   // If there was no name specified, check to see if there was a useful name
1172   // specified in the asm file.
1173   if (hadNames || parser.getNumResults() == 0)
1174     return success();
1175 
1176   SmallVector<StringRef, 4> names;
1177   auto *context = result.getContext();
1178 
1179   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1180     auto resultName = parser.getResultName(i);
1181     StringRef nameStr;
1182     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1183       nameStr = resultName.first;
1184 
1185     names.push_back(nameStr);
1186   }
1187 
1188   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1189   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1190   return success();
1191 }
1192 
1193 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1194   // Note that we only need to print the "name" attribute if the asmprinter
1195   // result name disagrees with it.  This can happen in strange cases, e.g.
1196   // when there are conflicts.
1197   bool namesDisagree = getNames().size() != getNumResults();
1198 
1199   SmallString<32> resultNameStr;
1200   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1201     resultNameStr.clear();
1202     llvm::raw_svector_ostream tmpStream(resultNameStr);
1203     p.printOperand(getResult(i), tmpStream);
1204 
1205     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1206     if (!expectedName ||
1207         tmpStream.str().drop_front() != expectedName.getValue()) {
1208       namesDisagree = true;
1209     }
1210   }
1211 
1212   if (namesDisagree)
1213     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1214   else
1215     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1216 }
1217 
1218 // We set the SSA name in the asm syntax to the contents of the name
1219 // attribute.
1220 void StringAttrPrettyNameOp::getAsmResultNames(
1221     function_ref<void(Value, StringRef)> setNameFn) {
1222 
1223   auto value = getNames();
1224   for (size_t i = 0, e = value.size(); i != e; ++i)
1225     if (auto str = value[i].dyn_cast<StringAttr>())
1226       if (!str.getValue().empty())
1227         setNameFn(getResult(i), str.getValue());
1228 }
1229 
1230 void CustomResultsNameOp::getAsmResultNames(
1231     function_ref<void(Value, StringRef)> setNameFn) {
1232   ArrayAttr value = getNames();
1233   for (size_t i = 0, e = value.size(); i != e; ++i)
1234     if (auto str = value[i].dyn_cast<StringAttr>())
1235       if (!str.getValue().empty())
1236         setNameFn(getResult(i), str.getValue());
1237 }
1238 
1239 //===----------------------------------------------------------------------===//
1240 // ResultTypeWithTraitOp
1241 //===----------------------------------------------------------------------===//
1242 
1243 LogicalResult ResultTypeWithTraitOp::verify() {
1244   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1245     return success();
1246   return emitError("result type should have trait 'TestTypeTrait'");
1247 }
1248 
1249 //===----------------------------------------------------------------------===//
1250 // AttrWithTraitOp
1251 //===----------------------------------------------------------------------===//
1252 
1253 LogicalResult AttrWithTraitOp::verify() {
1254   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1255     return success();
1256   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1257 }
1258 
1259 //===----------------------------------------------------------------------===//
1260 // RegionIfOp
1261 //===----------------------------------------------------------------------===//
1262 
1263 void RegionIfOp::print(OpAsmPrinter &p) {
1264   p << " ";
1265   p.printOperands(getOperands());
1266   p << ": " << getOperandTypes();
1267   p.printArrowTypeList(getResultTypes());
1268   p << " then ";
1269   p.printRegion(getThenRegion(),
1270                 /*printEntryBlockArgs=*/true,
1271                 /*printBlockTerminators=*/true);
1272   p << " else ";
1273   p.printRegion(getElseRegion(),
1274                 /*printEntryBlockArgs=*/true,
1275                 /*printBlockTerminators=*/true);
1276   p << " join ";
1277   p.printRegion(getJoinRegion(),
1278                 /*printEntryBlockArgs=*/true,
1279                 /*printBlockTerminators=*/true);
1280 }
1281 
1282 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1283   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1284   SmallVector<Type, 2> operandTypes;
1285 
1286   result.regions.reserve(3);
1287   Region *thenRegion = result.addRegion();
1288   Region *elseRegion = result.addRegion();
1289   Region *joinRegion = result.addRegion();
1290 
1291   // Parse operand, type and arrow type lists.
1292   if (parser.parseOperandList(operandInfos) ||
1293       parser.parseColonTypeList(operandTypes) ||
1294       parser.parseArrowTypeList(result.types))
1295     return failure();
1296 
1297   // Parse all attached regions.
1298   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1299       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1300       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1301     return failure();
1302 
1303   return parser.resolveOperands(operandInfos, operandTypes,
1304                                 parser.getCurrentLocation(), result.operands);
1305 }
1306 
1307 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
1308   assert(index < 2 && "invalid region index");
1309   return getOperands();
1310 }
1311 
1312 void RegionIfOp::getSuccessorRegions(
1313     Optional<unsigned> index, ArrayRef<Attribute> operands,
1314     SmallVectorImpl<RegionSuccessor> &regions) {
1315   // We always branch to the join region.
1316   if (index.hasValue()) {
1317     if (index.getValue() < 2)
1318       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1319     else
1320       regions.push_back(RegionSuccessor(getResults()));
1321     return;
1322   }
1323 
1324   // The then and else regions are the entry regions of this op.
1325   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1326   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1327 }
1328 
1329 void RegionIfOp::getRegionInvocationBounds(
1330     ArrayRef<Attribute> operands,
1331     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1332   // Each region is invoked at most once.
1333   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1334 }
1335 
1336 //===----------------------------------------------------------------------===//
1337 // AnyCondOp
1338 //===----------------------------------------------------------------------===//
1339 
1340 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1341                                     ArrayRef<Attribute> operands,
1342                                     SmallVectorImpl<RegionSuccessor> &regions) {
1343   // The parent op branches into the only region, and the region branches back
1344   // to the parent op.
1345   if (index)
1346     regions.emplace_back(&getRegion());
1347   else
1348     regions.emplace_back(getResults());
1349 }
1350 
1351 void AnyCondOp::getRegionInvocationBounds(
1352     ArrayRef<Attribute> operands,
1353     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1354   invocationBounds.emplace_back(1, 1);
1355 }
1356 
1357 //===----------------------------------------------------------------------===//
1358 // SingleNoTerminatorCustomAsmOp
1359 //===----------------------------------------------------------------------===//
1360 
1361 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1362                                                  OperationState &state) {
1363   Region *body = state.addRegion();
1364   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1365     return failure();
1366   return success();
1367 }
1368 
1369 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1370   printer.printRegion(
1371       getRegion(), /*printEntryBlockArgs=*/false,
1372       // This op has a single block without terminators. But explicitly mark
1373       // as not printing block terminators for testing.
1374       /*printBlockTerminators=*/false);
1375 }
1376 
1377 //===----------------------------------------------------------------------===//
1378 // TestVerifiersOp
1379 //===----------------------------------------------------------------------===//
1380 
1381 LogicalResult TestVerifiersOp::verify() {
1382   if (!getRegion().hasOneBlock())
1383     return emitOpError("`hasOneBlock` trait hasn't been verified");
1384 
1385   Operation *definingOp = getInput().getDefiningOp();
1386   if (definingOp && failed(mlir::verify(definingOp)))
1387     return emitOpError("operand hasn't been verified");
1388 
1389   emitRemark("success run of verifier");
1390 
1391   return success();
1392 }
1393 
1394 LogicalResult TestVerifiersOp::verifyRegions() {
1395   if (!getRegion().hasOneBlock())
1396     return emitOpError("`hasOneBlock` trait hasn't been verified");
1397 
1398   for (Block &block : getRegion())
1399     for (Operation &op : block)
1400       if (failed(mlir::verify(&op)))
1401         return emitOpError("nested op hasn't been verified");
1402 
1403   emitRemark("success run of region verifier");
1404 
1405   return success();
1406 }
1407 
1408 #include "TestOpEnums.cpp.inc"
1409 #include "TestOpInterfaces.cpp.inc"
1410 #include "TestOpStructs.cpp.inc"
1411 #include "TestTypeInterfaces.cpp.inc"
1412 
1413 #define GET_OP_CLASSES
1414 #include "TestOps.cpp.inc"
1415