1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/ExtensibleDialect.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/Verifier.h"
28 #include "mlir/Interfaces/InferIntRangeInterface.h"
29 #include "mlir/Reducer/ReductionPatternInterface.h"
30 #include "mlir/Transforms/FoldUtils.h"
31 #include "mlir/Transforms/InliningUtils.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/ADT/StringSwitch.h"
35 
36 // Include this before the using namespace lines below to
37 // test that we don't have namespace dependencies.
38 #include "TestOpsDialect.cpp.inc"
39 
40 using namespace mlir;
41 using namespace test;
42 
43 void test::registerTestDialect(DialectRegistry &registry) {
44   registry.insert<TestDialect>();
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // External Elements Data
49 //===----------------------------------------------------------------------===//
50 
51 ArrayRef<uint64_t> TestExternalElementsData::getData() const {
52   ArrayRef<char> data = AsmResourceBlob::getData();
53   return ArrayRef<uint64_t>((const uint64_t *)data.data(),
54                             data.size() / sizeof(uint64_t));
55 }
56 
57 TestExternalElementsData
58 TestExternalElementsData::allocate(size_t numElements) {
59   return TestExternalElementsData(
60       llvm::ArrayRef<uint64_t>(new uint64_t[numElements], numElements),
61       [](const uint64_t *data, size_t) { delete[] data; },
62       /*dataIsMutable=*/true);
63 }
64 
65 const TestExternalElementsData *
66 TestExternalElementsDataManager::getData(StringRef name) const {
67   auto it = dataMap.find(name);
68   return it != dataMap.end() ? &*it->second : nullptr;
69 }
70 
71 std::pair<TestExternalElementsDataManager::DataMap::iterator, bool>
72 TestExternalElementsDataManager::insert(StringRef name) {
73   auto it = dataMap.try_emplace(name, nullptr);
74   if (it.second)
75     return it;
76 
77   llvm::SmallString<32> nameStorage(name);
78   nameStorage.push_back('_');
79   size_t nameCounter = 1;
80   do {
81     nameStorage += std::to_string(nameCounter++);
82     auto it = dataMap.try_emplace(nameStorage, nullptr);
83     if (it.second)
84       return it;
85     nameStorage.resize(name.size() + 1);
86   } while (true);
87 }
88 
89 void TestExternalElementsDataManager::setData(StringRef name,
90                                               TestExternalElementsData &&data) {
91   auto it = dataMap.find(name);
92   assert(it != dataMap.end() && "data not registered");
93   it->second = std::make_unique<TestExternalElementsData>(std::move(data));
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // TestDialect Interfaces
98 //===----------------------------------------------------------------------===//
99 
100 namespace {
101 
102 /// Testing the correctness of some traits.
103 static_assert(
104     llvm::is_detected<OpTrait::has_implicit_terminator_t,
105                       SingleBlockImplicitTerminatorOp>::value,
106     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
107 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
108                   SingleBlockImplicitTerminatorOp>::value,
109               "hasSingleBlockImplicitTerminator does not match "
110               "SingleBlockImplicitTerminatorOp");
111 
112 // Test support for interacting with the AsmPrinter.
113 struct TestOpAsmInterface : public OpAsmDialectInterface {
114   using OpAsmDialectInterface::OpAsmDialectInterface;
115 
116   //===------------------------------------------------------------------===//
117   // Aliases
118   //===------------------------------------------------------------------===//
119 
120   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
121     StringAttr strAttr = attr.dyn_cast<StringAttr>();
122     if (!strAttr)
123       return AliasResult::NoAlias;
124 
125     // Check the contents of the string attribute to see what the test alias
126     // should be named.
127     Optional<StringRef> aliasName =
128         StringSwitch<Optional<StringRef>>(strAttr.getValue())
129             .Case("alias_test:dot_in_name", StringRef("test.alias"))
130             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
131             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
132             .Case("alias_test:sanitize_conflict_a",
133                   StringRef("test_alias_conflict0"))
134             .Case("alias_test:sanitize_conflict_b",
135                   StringRef("test_alias_conflict0_"))
136             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
137             .Default(llvm::None);
138     if (!aliasName)
139       return AliasResult::NoAlias;
140 
141     os << *aliasName;
142     return AliasResult::FinalAlias;
143   }
144 
145   AliasResult getAlias(Type type, raw_ostream &os) const final {
146     if (auto tupleType = type.dyn_cast<TupleType>()) {
147       if (tupleType.size() > 0 &&
148           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
149             return elemType.isa<SimpleAType>();
150           })) {
151         os << "test_tuple";
152         return AliasResult::FinalAlias;
153       }
154     }
155     if (auto intType = type.dyn_cast<TestIntegerType>()) {
156       if (intType.getSignedness() ==
157               TestIntegerType::SignednessSemantics::Unsigned &&
158           intType.getWidth() == 8) {
159         os << "test_ui8";
160         return AliasResult::FinalAlias;
161       }
162     }
163     return AliasResult::NoAlias;
164   }
165 
166   //===------------------------------------------------------------------===//
167   // Resources
168   //===------------------------------------------------------------------===//
169 
170   std::string
171   getResourceKey(const AsmDialectResourceHandle &handle) const override {
172     return cast<TestExternalElementsDataHandle>(handle).getKey().str();
173   }
174 
175   FailureOr<AsmDialectResourceHandle>
176   declareResource(StringRef key) const final {
177     TestDialect *dialect = cast<TestDialect>(getDialect());
178     TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
179 
180     // Resolve the reference by inserting a new entry into the manager.
181     auto it = mgr.insert(key).first;
182     return TestExternalElementsDataHandle(&*it, dialect);
183   }
184 
185   LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
186     TestDialect *dialect = cast<TestDialect>(getDialect());
187     TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
188 
189     // The resource entries are external constant data.
190     auto blobAllocFn = [](unsigned size, unsigned align) {
191       assert(align == alignof(uint64_t) && "unexpected data alignment");
192       return TestExternalElementsData::allocate(size / sizeof(uint64_t));
193     };
194     FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(blobAllocFn);
195     if (failed(blob))
196       return failure();
197 
198     mgr.setData(entry.getKey(), std::move(*blob));
199     return success();
200   }
201 
202   void
203   buildResources(Operation *op,
204                  const SetVector<AsmDialectResourceHandle> &referencedResources,
205                  AsmResourceBuilder &provider) const final {
206     for (const AsmDialectResourceHandle &handle : referencedResources) {
207       const auto &testHandle = cast<TestExternalElementsDataHandle>(handle);
208       provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData());
209     }
210   }
211 };
212 
213 struct TestDialectFoldInterface : public DialectFoldInterface {
214   using DialectFoldInterface::DialectFoldInterface;
215 
216   /// Registered hook to check if the given region, which is attached to an
217   /// operation that is *not* isolated from above, should be used when
218   /// materializing constants.
219   bool shouldMaterializeInto(Region *region) const final {
220     // If this is a one region operation, then insert into it.
221     return isa<OneRegionOp>(region->getParentOp());
222   }
223 };
224 
225 /// This class defines the interface for handling inlining with standard
226 /// operations.
227 struct TestInlinerInterface : public DialectInlinerInterface {
228   using DialectInlinerInterface::DialectInlinerInterface;
229 
230   //===--------------------------------------------------------------------===//
231   // Analysis Hooks
232   //===--------------------------------------------------------------------===//
233 
234   bool isLegalToInline(Operation *call, Operation *callable,
235                        bool wouldBeCloned) const final {
236     // Don't allow inlining calls that are marked `noinline`.
237     return !call->hasAttr("noinline");
238   }
239   bool isLegalToInline(Region *, Region *, bool,
240                        BlockAndValueMapping &) const final {
241     // Inlining into test dialect regions is legal.
242     return true;
243   }
244   bool isLegalToInline(Operation *, Region *, bool,
245                        BlockAndValueMapping &) const final {
246     return true;
247   }
248 
249   bool shouldAnalyzeRecursively(Operation *op) const final {
250     // Analyze recursively if this is not a functional region operation, it
251     // froms a separate functional scope.
252     return !isa<FunctionalRegionOp>(op);
253   }
254 
255   //===--------------------------------------------------------------------===//
256   // Transformation Hooks
257   //===--------------------------------------------------------------------===//
258 
259   /// Handle the given inlined terminator by replacing it with a new operation
260   /// as necessary.
261   void handleTerminator(Operation *op,
262                         ArrayRef<Value> valuesToRepl) const final {
263     // Only handle "test.return" here.
264     auto returnOp = dyn_cast<TestReturnOp>(op);
265     if (!returnOp)
266       return;
267 
268     // Replace the values directly with the return operands.
269     assert(returnOp.getNumOperands() == valuesToRepl.size());
270     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
271       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
272   }
273 
274   /// Attempt to materialize a conversion for a type mismatch between a call
275   /// from this dialect, and a callable region. This method should generate an
276   /// operation that takes 'input' as the only operand, and produces a single
277   /// result of 'resultType'. If a conversion can not be generated, nullptr
278   /// should be returned.
279   Operation *materializeCallConversion(OpBuilder &builder, Value input,
280                                        Type resultType,
281                                        Location conversionLoc) const final {
282     // Only allow conversion for i16/i32 types.
283     if (!(resultType.isSignlessInteger(16) ||
284           resultType.isSignlessInteger(32)) ||
285         !(input.getType().isSignlessInteger(16) ||
286           input.getType().isSignlessInteger(32)))
287       return nullptr;
288     return builder.create<TestCastOp>(conversionLoc, resultType, input);
289   }
290 
291   void processInlinedCallBlocks(
292       Operation *call,
293       iterator_range<Region::iterator> inlinedBlocks) const final {
294     if (!isa<ConversionCallOp>(call))
295       return;
296 
297     // Set attributed on all ops in the inlined blocks.
298     for (Block &block : inlinedBlocks) {
299       block.walk([&](Operation *op) {
300         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
301       });
302     }
303   }
304 };
305 
306 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
307 public:
308   TestReductionPatternInterface(Dialect *dialect)
309       : DialectReductionPatternInterface(dialect) {}
310 
311   void populateReductionPatterns(RewritePatternSet &patterns) const final {
312     populateTestReductionPatterns(patterns);
313   }
314 };
315 
316 } // namespace
317 
318 //===----------------------------------------------------------------------===//
319 // Dynamic operations
320 //===----------------------------------------------------------------------===//
321 
322 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
323   return DynamicOpDefinition::get(
324       "dynamic_generic", dialect, [](Operation *op) { return success(); },
325       [](Operation *op) { return success(); });
326 }
327 
328 std::unique_ptr<DynamicOpDefinition>
329 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
330   return DynamicOpDefinition::get(
331       "dynamic_one_operand_two_results", dialect,
332       [](Operation *op) {
333         if (op->getNumOperands() != 1) {
334           op->emitOpError()
335               << "expected 1 operand, but had " << op->getNumOperands();
336           return failure();
337         }
338         if (op->getNumResults() != 2) {
339           op->emitOpError()
340               << "expected 2 results, but had " << op->getNumResults();
341           return failure();
342         }
343         return success();
344       },
345       [](Operation *op) { return success(); });
346 }
347 
348 std::unique_ptr<DynamicOpDefinition>
349 getDynamicCustomParserPrinterOp(TestDialect *dialect) {
350   auto verifier = [](Operation *op) {
351     if (op->getNumOperands() == 0 && op->getNumResults() == 0)
352       return success();
353     op->emitError() << "operation should have no operands and no results";
354     return failure();
355   };
356   auto regionVerifier = [](Operation *op) { return success(); };
357 
358   auto parser = [](OpAsmParser &parser, OperationState &state) {
359     return parser.parseKeyword("custom_keyword");
360   };
361 
362   auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
363     printer << op->getName() << " custom_keyword";
364   };
365 
366   return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
367                                   verifier, regionVerifier, parser, printer);
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // TestDialect
372 //===----------------------------------------------------------------------===//
373 
374 static void testSideEffectOpGetEffect(
375     Operation *op,
376     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
377 
378 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
379 struct TestOpEffectInterfaceFallback
380     : public TestEffectOpInterface::FallbackModel<
381           TestOpEffectInterfaceFallback> {
382   static bool classof(Operation *op) {
383     bool isSupportedOp =
384         op->getName().getStringRef() == "test.unregistered_side_effect_op";
385     assert(isSupportedOp && "Unexpected dispatch");
386     return isSupportedOp;
387   }
388 
389   void
390   getEffects(Operation *op,
391              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
392                  &effects) const {
393     testSideEffectOpGetEffect(op, effects);
394   }
395 };
396 
397 void TestDialect::initialize() {
398   registerAttributes();
399   registerTypes();
400   addOperations<
401 #define GET_OP_LIST
402 #include "TestOps.cpp.inc"
403       >();
404   registerDynamicOp(getDynamicGenericOp(this));
405   registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
406   registerDynamicOp(getDynamicCustomParserPrinterOp(this));
407 
408   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
409                 TestInlinerInterface, TestReductionPatternInterface>();
410   allowUnknownOperations();
411 
412   // Instantiate our fallback op interface that we'll use on specific
413   // unregistered op.
414   fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
415 }
416 TestDialect::~TestDialect() {
417   delete static_cast<TestOpEffectInterfaceFallback *>(
418       fallbackEffectOpInterfaces);
419 }
420 
421 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
422                                             Type type, Location loc) {
423   return builder.create<TestOpConstant>(loc, type, value);
424 }
425 
426 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
427     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
428     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
429     ::mlir::RegionRange regions,
430     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
431   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
432   return ::mlir::success();
433 }
434 
435 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
436                                                OperationName opName) {
437   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
438       typeID == TypeID::get<TestEffectOpInterface>())
439     return fallbackEffectOpInterfaces;
440   return nullptr;
441 }
442 
443 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
444                                                     NamedAttribute namedAttr) {
445   if (namedAttr.getName() == "test.invalid_attr")
446     return op->emitError() << "invalid to use 'test.invalid_attr'";
447   return success();
448 }
449 
450 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
451                                                     unsigned regionIndex,
452                                                     unsigned argIndex,
453                                                     NamedAttribute namedAttr) {
454   if (namedAttr.getName() == "test.invalid_attr")
455     return op->emitError() << "invalid to use 'test.invalid_attr'";
456   return success();
457 }
458 
459 LogicalResult
460 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
461                                          unsigned resultIndex,
462                                          NamedAttribute namedAttr) {
463   if (namedAttr.getName() == "test.invalid_attr")
464     return op->emitError() << "invalid to use 'test.invalid_attr'";
465   return success();
466 }
467 
468 Optional<Dialect::ParseOpHook>
469 TestDialect::getParseOperationHook(StringRef opName) const {
470   if (opName == "test.dialect_custom_printer") {
471     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
472       return parser.parseKeyword("custom_format");
473     }};
474   }
475   if (opName == "test.dialect_custom_format_fallback") {
476     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
477       return parser.parseKeyword("custom_format_fallback");
478     }};
479   }
480   if (opName == "test.dialect_custom_printer.with.dot") {
481     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
482       return ParseResult::success();
483     }};
484   }
485   return None;
486 }
487 
488 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
489 TestDialect::getOperationPrinter(Operation *op) const {
490   StringRef opName = op->getName().getStringRef();
491   if (opName == "test.dialect_custom_printer") {
492     return [](Operation *op, OpAsmPrinter &printer) {
493       printer.getStream() << " custom_format";
494     };
495   }
496   if (opName == "test.dialect_custom_format_fallback") {
497     return [](Operation *op, OpAsmPrinter &printer) {
498       printer.getStream() << " custom_format_fallback";
499     };
500   }
501   return {};
502 }
503 
504 //===----------------------------------------------------------------------===//
505 // TestBranchOp
506 //===----------------------------------------------------------------------===//
507 
508 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
509   assert(index == 0 && "invalid successor index");
510   return SuccessorOperands(getTargetOperandsMutable());
511 }
512 
513 //===----------------------------------------------------------------------===//
514 // TestProducingBranchOp
515 //===----------------------------------------------------------------------===//
516 
517 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
518   assert(index <= 1 && "invalid successor index");
519   if (index == 1)
520     return SuccessorOperands(getFirstOperandsMutable());
521   return SuccessorOperands(getSecondOperandsMutable());
522 }
523 
524 //===----------------------------------------------------------------------===//
// TestInternalBranchOp
526 //===----------------------------------------------------------------------===//
527 
528 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
529   assert(index <= 1 && "invalid successor index");
530   if (index == 0)
531     return SuccessorOperands(0, getSuccessOperandsMutable());
532   return SuccessorOperands(1, getErrorOperandsMutable());
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // TestDialectCanonicalizerOp
537 //===----------------------------------------------------------------------===//
538 
539 static LogicalResult
540 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
541                                PatternRewriter &rewriter) {
542   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
543       op, rewriter.getI32IntegerAttr(42));
544   return success();
545 }
546 
547 void TestDialect::getCanonicalizationPatterns(
548     RewritePatternSet &results) const {
549   results.add(&dialectCanonicalizationPattern);
550 }
551 
552 //===----------------------------------------------------------------------===//
553 // TestCallOp
554 //===----------------------------------------------------------------------===//
555 
556 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
557   // Check that the callee attribute was specified.
558   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
559   if (!fnAttr)
560     return emitOpError("requires a 'callee' symbol reference attribute");
561   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
562     return emitOpError() << "'" << fnAttr.getValue()
563                          << "' does not reference a valid function";
564   return success();
565 }
566 
567 //===----------------------------------------------------------------------===//
568 // TestFoldToCallOp
569 //===----------------------------------------------------------------------===//
570 
571 namespace {
572 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
573   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
574 
575   LogicalResult matchAndRewrite(FoldToCallOp op,
576                                 PatternRewriter &rewriter) const override {
577     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
578                                               op.getCalleeAttr(), ValueRange());
579     return success();
580   }
581 };
582 } // namespace
583 
584 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
585                                                MLIRContext *context) {
586   results.add<FoldToCallOpPattern>(context);
587 }
588 
589 //===----------------------------------------------------------------------===//
590 // Test Format* operations
591 //===----------------------------------------------------------------------===//
592 
593 //===----------------------------------------------------------------------===//
594 // Parsing
595 
596 static ParseResult parseCustomOptionalOperand(
597     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
598   if (succeeded(parser.parseOptionalLParen())) {
599     optOperand.emplace();
600     if (parser.parseOperand(*optOperand) || parser.parseRParen())
601       return failure();
602   }
603   return success();
604 }
605 
606 static ParseResult parseCustomDirectiveOperands(
607     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
608     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
609     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
610   if (parser.parseOperand(operand))
611     return failure();
612   if (succeeded(parser.parseOptionalComma())) {
613     optOperand.emplace();
614     if (parser.parseOperand(*optOperand))
615       return failure();
616   }
617   if (parser.parseArrow() || parser.parseLParen() ||
618       parser.parseOperandList(varOperands) || parser.parseRParen())
619     return failure();
620   return success();
621 }
622 static ParseResult
623 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
624                             Type &optOperandType,
625                             SmallVectorImpl<Type> &varOperandTypes) {
626   if (parser.parseColon())
627     return failure();
628 
629   if (parser.parseType(operandType))
630     return failure();
631   if (succeeded(parser.parseOptionalComma())) {
632     if (parser.parseType(optOperandType))
633       return failure();
634   }
635   if (parser.parseArrow() || parser.parseLParen() ||
636       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
637     return failure();
638   return success();
639 }
640 static ParseResult
641 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
642                                  Type optOperandType,
643                                  const SmallVectorImpl<Type> &varOperandTypes) {
644   if (parser.parseKeyword("type_refs_capture"))
645     return failure();
646 
647   Type operandType2, optOperandType2;
648   SmallVector<Type, 1> varOperandTypes2;
649   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
650                                   varOperandTypes2))
651     return failure();
652 
653   if (operandType != operandType2 || optOperandType != optOperandType2 ||
654       varOperandTypes != varOperandTypes2)
655     return failure();
656 
657   return success();
658 }
659 static ParseResult parseCustomDirectiveOperandsAndTypes(
660     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
661     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
662     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
663     Type &operandType, Type &optOperandType,
664     SmallVectorImpl<Type> &varOperandTypes) {
665   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
666       parseCustomDirectiveResults(parser, operandType, optOperandType,
667                                   varOperandTypes))
668     return failure();
669   return success();
670 }
671 static ParseResult parseCustomDirectiveRegions(
672     OpAsmParser &parser, Region &region,
673     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
674   if (parser.parseRegion(region))
675     return failure();
676   if (failed(parser.parseOptionalComma()))
677     return success();
678   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
679   if (parser.parseRegion(*varRegion))
680     return failure();
681   varRegions.emplace_back(std::move(varRegion));
682   return success();
683 }
684 static ParseResult
685 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
686                                SmallVectorImpl<Block *> &varSuccessors) {
687   if (parser.parseSuccessor(successor))
688     return failure();
689   if (failed(parser.parseOptionalComma()))
690     return success();
691   Block *varSuccessor;
692   if (parser.parseSuccessor(varSuccessor))
693     return failure();
694   varSuccessors.append(2, varSuccessor);
695   return success();
696 }
697 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
698                                                   IntegerAttr &attr,
699                                                   IntegerAttr &optAttr) {
700   if (parser.parseAttribute(attr))
701     return failure();
702   if (succeeded(parser.parseOptionalComma())) {
703     if (parser.parseAttribute(optAttr))
704       return failure();
705   }
706   return success();
707 }
708 
709 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
710                                                 NamedAttrList &attrs) {
711   return parser.parseOptionalAttrDict(attrs);
712 }
713 static ParseResult parseCustomDirectiveOptionalOperandRef(
714     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
715   int64_t operandCount = 0;
716   if (parser.parseInteger(operandCount))
717     return failure();
718   bool expectedOptionalOperand = operandCount == 0;
719   return success(expectedOptionalOperand != optOperand.has_value());
720 }
721 
722 //===----------------------------------------------------------------------===//
723 // Printing
724 
725 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
726                                        Value optOperand) {
727   if (optOperand)
728     printer << "(" << optOperand << ") ";
729 }
730 
731 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
732                                          Value operand, Value optOperand,
733                                          OperandRange varOperands) {
734   printer << operand;
735   if (optOperand)
736     printer << ", " << optOperand;
737   printer << " -> (" << varOperands << ")";
738 }
739 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
740                                         Type operandType, Type optOperandType,
741                                         TypeRange varOperandTypes) {
742   printer << " : " << operandType;
743   if (optOperandType)
744     printer << ", " << optOperandType;
745   printer << " -> (" << varOperandTypes << ")";
746 }
747 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
748                                              Operation *op, Type operandType,
749                                              Type optOperandType,
750                                              TypeRange varOperandTypes) {
751   printer << " type_refs_capture ";
752   printCustomDirectiveResults(printer, op, operandType, optOperandType,
753                               varOperandTypes);
754 }
755 static void printCustomDirectiveOperandsAndTypes(
756     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
757     OperandRange varOperands, Type operandType, Type optOperandType,
758     TypeRange varOperandTypes) {
759   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
760   printCustomDirectiveResults(printer, op, operandType, optOperandType,
761                               varOperandTypes);
762 }
763 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
764                                         Region &region,
765                                         MutableArrayRef<Region> varRegions) {
766   printer.printRegion(region);
767   if (!varRegions.empty()) {
768     printer << ", ";
769     for (Region &region : varRegions)
770       printer.printRegion(region);
771   }
772 }
773 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
774                                            Block *successor,
775                                            SuccessorRange varSuccessors) {
776   printer << successor;
777   if (!varSuccessors.empty())
778     printer << ", " << varSuccessors.front();
779 }
780 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
781                                            Attribute attribute,
782                                            Attribute optAttribute) {
783   printer << attribute;
784   if (optAttribute)
785     printer << ", " << optAttribute;
786 }
787 
788 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
789                                          DictionaryAttr attrs) {
790   printer.printOptionalAttrDict(attrs.getValue());
791 }
792 
793 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
794                                                    Operation *op,
795                                                    Value optOperand) {
796   printer << (optOperand ? "1" : "0");
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // Test IsolatedRegionOp - parse passthrough region arguments.
801 //===----------------------------------------------------------------------===//
802 
803 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
804                                     OperationState &result) {
805   // Parse the input operand.
806   OpAsmParser::Argument argInfo;
807   argInfo.type = parser.getBuilder().getIndexType();
808   if (parser.parseOperand(argInfo.ssaName) ||
809       parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
810     return failure();
811 
812   // Parse the body region, and reuse the operand info as the argument info.
813   Region *body = result.addRegion();
814   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
815 }
816 
817 void IsolatedRegionOp::print(OpAsmPrinter &p) {
818   p << "test.isolated_region ";
819   p.printOperand(getOperand());
820   p.shadowRegionArgs(getRegion(), getOperand());
821   p << ' ';
822   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
823 }
824 
825 //===----------------------------------------------------------------------===//
826 // Test SSACFGRegionOp
827 //===----------------------------------------------------------------------===//
828 
829 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
830   return RegionKind::SSACFG;
831 }
832 
833 //===----------------------------------------------------------------------===//
834 // Test GraphRegionOp
835 //===----------------------------------------------------------------------===//
836 
837 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
838   return RegionKind::Graph;
839 }
840 
841 //===----------------------------------------------------------------------===//
842 // Test AffineScopeOp
843 //===----------------------------------------------------------------------===//
844 
845 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
846   // Parse the body region, and reuse the operand info as the argument info.
847   Region *body = result.addRegion();
848   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
849 }
850 
851 void AffineScopeOp::print(OpAsmPrinter &p) {
852   p << "test.affine_scope ";
853   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
854 }
855 
856 //===----------------------------------------------------------------------===//
857 // Test parser.
858 //===----------------------------------------------------------------------===//
859 
860 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
861                                          OperationState &result) {
862   if (parser.parseOptionalColon())
863     return success();
864   uint64_t numResults;
865   if (parser.parseInteger(numResults))
866     return failure();
867 
868   IndexType type = parser.getBuilder().getIndexType();
869   for (unsigned i = 0; i < numResults; ++i)
870     result.addTypes(type);
871   return success();
872 }
873 
874 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
875   if (unsigned numResults = getNumResults())
876     p << " : " << numResults;
877 }
878 
879 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
880                                          OperationState &result) {
881   StringRef keyword;
882   if (parser.parseKeyword(&keyword))
883     return failure();
884   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
885   return success();
886 }
887 
888 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
889 
890 //===----------------------------------------------------------------------===//
891 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
892 
893 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
894                                     OperationState &result) {
895   if (parser.parseKeyword("wraps"))
896     return failure();
897 
898   // Parse the wrapped op in a region
899   Region &body = *result.addRegion();
900   body.push_back(new Block);
901   Block &block = body.back();
902   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
903   if (!wrappedOp)
904     return failure();
905 
906   // Create a return terminator in the inner region, pass as operand to the
907   // terminator the returned values from the wrapped operation.
908   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
909   OpBuilder builder(parser.getContext());
910   builder.setInsertionPointToEnd(&block);
911   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
912 
913   // Get the results type for the wrapping op from the terminator operands.
914   Operation &returnOp = body.back().back();
915   result.types.append(returnOp.operand_type_begin(),
916                       returnOp.operand_type_end());
917 
918   // Use the location of the wrapped op for the "test.wrapping_region" op.
919   result.location = wrappedOp->getLoc();
920 
921   return success();
922 }
923 
924 void WrappingRegionOp::print(OpAsmPrinter &p) {
925   p << " wraps ";
926   p.printGenericOp(&getRegion().front().front());
927 }
928 
929 //===----------------------------------------------------------------------===//
// Test PrettyPrintedRegionOp - exercising the following parser APIs
931 //   parseGenericOperationAfterOpName
932 //   parseCustomOperationName
933 //===----------------------------------------------------------------------===//
934 
935 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
936                                          OperationState &result) {
937 
938   SMLoc loc = parser.getCurrentLocation();
939   Location currLocation = parser.getEncodedSourceLoc(loc);
940 
941   // Parse the operands.
942   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
943   if (parser.parseOperandList(operands))
944     return failure();
945 
946   // Check if we are parsing the pretty-printed version
947   //  test.pretty_printed_region start <inner-op> end : <functional-type>
948   // Else fallback to parsing the "non pretty-printed" version.
949   if (!succeeded(parser.parseOptionalKeyword("start")))
950     return parser.parseGenericOperationAfterOpName(
951         result, llvm::makeArrayRef(operands));
952 
953   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
954   if (failed(parseOpNameInfo))
955     return failure();
956 
957   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
958 
959   FunctionType opFntype;
960   Optional<Location> explicitLoc;
961   if (parser.parseKeyword("end") || parser.parseColon() ||
962       parser.parseType(opFntype) ||
963       parser.parseOptionalLocationSpecifier(explicitLoc))
964     return failure();
965 
966   // If location of the op is explicitly provided, then use it; Else use
967   // the parser's current location.
968   Location opLoc = explicitLoc.value_or(currLocation);
969 
970   // Derive the SSA-values for op's operands.
971   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
972                              result.operands))
973     return failure();
974 
975   // Add a region for op.
976   Region &region = *result.addRegion();
977 
978   // Create a basic-block inside op's region.
979   Block &block = region.emplaceBlock();
980 
981   // Create and insert an "inner-op" operation in the block.
982   // Just for testing purposes, we can assume that inner op is a binary op with
983   // result and operand types all same as the test-op's first operand.
984   Type innerOpType = opFntype.getInput(0);
985   Value lhs = block.addArgument(innerOpType, opLoc);
986   Value rhs = block.addArgument(innerOpType, opLoc);
987 
988   OpBuilder builder(parser.getBuilder().getContext());
989   builder.setInsertionPointToStart(&block);
990 
991   Operation *innerOp =
992       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
993 
994   // Insert a return statement in the block returning the inner-op's result.
995   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
996 
997   // Populate the op operation-state with result-type and location.
998   result.addTypes(opFntype.getResults());
999   result.location = innerOp->getLoc();
1000 
1001   return success();
1002 }
1003 
1004 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
1005   p << ' ';
1006   p.printOperands(getOperands());
1007 
1008   Operation &innerOp = getRegion().front().front();
1009   // Assuming that region has a single non-terminator inner-op, if the inner-op
1010   // meets some criteria (which in this case is a simple one  based on the name
1011   // of inner-op), then we can print the entire region in a succinct way.
1012   // Here we assume that the prototype of "special.op" can be trivially derived
1013   // while parsing it back.
1014   if (innerOp.getName().getStringRef().equals("special.op")) {
1015     p << " start special.op end";
1016   } else {
1017     p << " (";
1018     p.printRegion(getRegion());
1019     p << ")";
1020   }
1021 
1022   p << " : ";
1023   p.printFunctionalType(*this);
1024 }
1025 
1026 //===----------------------------------------------------------------------===//
1027 // Test PolyForOp - parse list of region arguments.
1028 //===----------------------------------------------------------------------===//
1029 
1030 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
1031   SmallVector<OpAsmParser::Argument, 4> ivsInfo;
1032   // Parse list of region arguments without a delimiter.
1033   if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
1034     return failure();
1035 
1036   // Parse the body region.
1037   Region *body = result.addRegion();
1038   for (auto &iv : ivsInfo)
1039     iv.type = parser.getBuilder().getIndexType();
1040   return parser.parseRegion(*body, ivsInfo);
1041 }
1042 
1043 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
1044 
1045 void PolyForOp::getAsmBlockArgumentNames(Region &region,
1046                                          OpAsmSetValueNameFn setNameFn) {
1047   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
1048   if (!arrayAttr)
1049     return;
1050   auto args = getRegion().front().getArguments();
1051   auto e = std::min(arrayAttr.size(), args.size());
1052   for (unsigned i = 0; i < e; ++i) {
1053     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
1054       setNameFn(args[i], strAttr.getValue());
1055   }
1056 }
1057 
1058 //===----------------------------------------------------------------------===//
1059 // Test removing op with inner ops.
1060 //===----------------------------------------------------------------------===//
1061 
1062 namespace {
1063 struct TestRemoveOpWithInnerOps
1064     : public OpRewritePattern<TestOpWithRegionPattern> {
1065   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
1066 
1067   void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
1068 
1069   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
1070                                 PatternRewriter &rewriter) const override {
1071     rewriter.eraseOp(op);
1072     return success();
1073   }
1074 };
1075 } // namespace
1076 
1077 void TestOpWithRegionPattern::getCanonicalizationPatterns(
1078     RewritePatternSet &results, MLIRContext *context) {
1079   results.add<TestRemoveOpWithInnerOps>(context);
1080 }
1081 
1082 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
1083   return getOperand();
1084 }
1085 
1086 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
1087   return getValue();
1088 }
1089 
1090 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
1091     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
1092   for (Value input : this->getOperands()) {
1093     results.push_back(input);
1094   }
1095   return success();
1096 }
1097 
1098 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
1099   assert(operands.size() == 1);
1100   if (operands.front()) {
1101     (*this)->setAttr("attr", operands.front());
1102     return getResult();
1103   }
1104   return {};
1105 }
1106 
1107 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
1108   return getOperand();
1109 }
1110 
1111 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
1112     MLIRContext *, Optional<Location> location, ValueRange operands,
1113     DictionaryAttr attributes, RegionRange regions,
1114     SmallVectorImpl<Type> &inferredReturnTypes) {
1115   if (operands[0].getType() != operands[1].getType()) {
1116     return emitOptionalError(location, "operand type mismatch ",
1117                              operands[0].getType(), " vs ",
1118                              operands[1].getType());
1119   }
1120   inferredReturnTypes.assign({operands[0].getType()});
1121   return success();
1122 }
1123 
1124 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
1125     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
1126     DictionaryAttr attributes, RegionRange regions,
1127     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1128   // Create return type consisting of the last element of the first operand.
1129   auto operandType = operands.front().getType();
1130   auto sval = operandType.dyn_cast<ShapedType>();
1131   if (!sval) {
1132     return emitOptionalError(location, "only shaped type operands allowed");
1133   }
1134   int64_t dim =
1135       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
1136   auto type = IntegerType::get(context, 17);
1137   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
1138   return success();
1139 }
1140 
1141 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
1142     OpBuilder &builder, ValueRange operands,
1143     llvm::SmallVectorImpl<Value> &shapes) {
1144   shapes = SmallVector<Value, 1>{
1145       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1146   return success();
1147 }
1148 
1149 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
1150     OpBuilder &builder, ValueRange operands,
1151     llvm::SmallVectorImpl<Value> &shapes) {
1152   Location loc = getLoc();
1153   shapes.reserve(operands.size());
1154   for (Value operand : llvm::reverse(operands)) {
1155     auto rank = operand.getType().cast<RankedTensorType>().getRank();
1156     auto currShape = llvm::to_vector<4>(
1157         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1158           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1159         }));
1160     shapes.push_back(builder.create<tensor::FromElementsOp>(
1161         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1162         currShape));
1163   }
1164   return success();
1165 }
1166 
1167 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1168     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1169   Location loc = getLoc();
1170   shapes.reserve(getNumOperands());
1171   for (Value operand : llvm::reverse(getOperands())) {
1172     auto currShape = llvm::to_vector<4>(llvm::map_range(
1173         llvm::seq<int64_t>(
1174             0, operand.getType().cast<RankedTensorType>().getRank()),
1175         [&](int64_t dim) -> Value {
1176           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1177         }));
1178     shapes.emplace_back(std::move(currShape));
1179   }
1180   return success();
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // Test SideEffect interfaces
1185 //===----------------------------------------------------------------------===//
1186 
namespace {
/// A test resource for side effects.
/// SideEffectOp::getEffects attaches an effect to this resource, instead of
/// the default resource, when an entry of the op's `effects` attribute has a
/// `test_resource` key.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  // NOTE(review): presumably provides this resource's TypeID inline; see the
  // macro's definition to confirm.
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)

  /// Returns the fixed name "<Test>" for this resource.
  StringRef getName() final { return "<Test>"; }
};
} // namespace
1195 
1196 static void testSideEffectOpGetEffect(
1197     Operation *op,
1198     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1199         &effects) {
1200   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1201   if (!effectsAttr)
1202     return;
1203 
1204   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1205 }
1206 
1207 void SideEffectOp::getEffects(
1208     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1209   // Check for an effects attribute on the op instance.
1210   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1211   if (!effectsAttr)
1212     return;
1213 
1214   // If there is one, it is an array of dictionary attributes that hold
1215   // information on the effects of this operation.
1216   for (Attribute element : effectsAttr) {
1217     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1218 
1219     // Get the specific memory effect.
1220     MemoryEffects::Effect *effect =
1221         StringSwitch<MemoryEffects::Effect *>(
1222             effectElement.get("effect").cast<StringAttr>().getValue())
1223             .Case("allocate", MemoryEffects::Allocate::get())
1224             .Case("free", MemoryEffects::Free::get())
1225             .Case("read", MemoryEffects::Read::get())
1226             .Case("write", MemoryEffects::Write::get());
1227 
1228     // Check for a non-default resource to use.
1229     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1230     if (effectElement.get("test_resource"))
1231       resource = TestResource::get();
1232 
1233     // Check for a result to affect.
1234     if (effectElement.get("on_result"))
1235       effects.emplace_back(effect, getResult(), resource);
1236     else if (Attribute ref = effectElement.get("on_reference"))
1237       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1238     else
1239       effects.emplace_back(effect, resource);
1240   }
1241 }
1242 
1243 void SideEffectOp::getEffects(
1244     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1245   testSideEffectOpGetEffect(getOperation(), effects);
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // StringAttrPrettyNameOp
1250 //===----------------------------------------------------------------------===//
1251 
1252 // This op has fancy handling of its SSA result name.
1253 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1254                                           OperationState &result) {
1255   // Add the result types.
1256   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1257     result.addTypes(parser.getBuilder().getIntegerType(32));
1258 
1259   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1260     return failure();
1261 
1262   // If the attribute dictionary contains no 'names' attribute, infer it from
1263   // the SSA name (if specified).
1264   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1265     return attr.getName() == "names";
1266   });
1267 
1268   // If there was no name specified, check to see if there was a useful name
1269   // specified in the asm file.
1270   if (hadNames || parser.getNumResults() == 0)
1271     return success();
1272 
1273   SmallVector<StringRef, 4> names;
1274   auto *context = result.getContext();
1275 
1276   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1277     auto resultName = parser.getResultName(i);
1278     StringRef nameStr;
1279     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1280       nameStr = resultName.first;
1281 
1282     names.push_back(nameStr);
1283   }
1284 
1285   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1286   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1287   return success();
1288 }
1289 
1290 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1291   // Note that we only need to print the "name" attribute if the asmprinter
1292   // result name disagrees with it.  This can happen in strange cases, e.g.
1293   // when there are conflicts.
1294   bool namesDisagree = getNames().size() != getNumResults();
1295 
1296   SmallString<32> resultNameStr;
1297   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1298     resultNameStr.clear();
1299     llvm::raw_svector_ostream tmpStream(resultNameStr);
1300     p.printOperand(getResult(i), tmpStream);
1301 
1302     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1303     if (!expectedName ||
1304         tmpStream.str().drop_front() != expectedName.getValue()) {
1305       namesDisagree = true;
1306     }
1307   }
1308 
1309   if (namesDisagree)
1310     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1311   else
1312     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1313 }
1314 
1315 // We set the SSA name in the asm syntax to the contents of the name
1316 // attribute.
1317 void StringAttrPrettyNameOp::getAsmResultNames(
1318     function_ref<void(Value, StringRef)> setNameFn) {
1319 
1320   auto value = getNames();
1321   for (size_t i = 0, e = value.size(); i != e; ++i)
1322     if (auto str = value[i].dyn_cast<StringAttr>())
1323       if (!str.getValue().empty())
1324         setNameFn(getResult(i), str.getValue());
1325 }
1326 
1327 void CustomResultsNameOp::getAsmResultNames(
1328     function_ref<void(Value, StringRef)> setNameFn) {
1329   ArrayAttr value = getNames();
1330   for (size_t i = 0, e = value.size(); i != e; ++i)
1331     if (auto str = value[i].dyn_cast<StringAttr>())
1332       if (!str.getValue().empty())
1333         setNameFn(getResult(i), str.getValue());
1334 }
1335 
1336 //===----------------------------------------------------------------------===//
1337 // ResultTypeWithTraitOp
1338 //===----------------------------------------------------------------------===//
1339 
1340 LogicalResult ResultTypeWithTraitOp::verify() {
1341   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1342     return success();
1343   return emitError("result type should have trait 'TestTypeTrait'");
1344 }
1345 
1346 //===----------------------------------------------------------------------===//
1347 // AttrWithTraitOp
1348 //===----------------------------------------------------------------------===//
1349 
1350 LogicalResult AttrWithTraitOp::verify() {
1351   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1352     return success();
1353   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1354 }
1355 
1356 //===----------------------------------------------------------------------===//
1357 // RegionIfOp
1358 //===----------------------------------------------------------------------===//
1359 
1360 void RegionIfOp::print(OpAsmPrinter &p) {
1361   p << " ";
1362   p.printOperands(getOperands());
1363   p << ": " << getOperandTypes();
1364   p.printArrowTypeList(getResultTypes());
1365   p << " then ";
1366   p.printRegion(getThenRegion(),
1367                 /*printEntryBlockArgs=*/true,
1368                 /*printBlockTerminators=*/true);
1369   p << " else ";
1370   p.printRegion(getElseRegion(),
1371                 /*printEntryBlockArgs=*/true,
1372                 /*printBlockTerminators=*/true);
1373   p << " join ";
1374   p.printRegion(getJoinRegion(),
1375                 /*printEntryBlockArgs=*/true,
1376                 /*printBlockTerminators=*/true);
1377 }
1378 
1379 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1380   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1381   SmallVector<Type, 2> operandTypes;
1382 
1383   result.regions.reserve(3);
1384   Region *thenRegion = result.addRegion();
1385   Region *elseRegion = result.addRegion();
1386   Region *joinRegion = result.addRegion();
1387 
1388   // Parse operand, type and arrow type lists.
1389   if (parser.parseOperandList(operandInfos) ||
1390       parser.parseColonTypeList(operandTypes) ||
1391       parser.parseArrowTypeList(result.types))
1392     return failure();
1393 
1394   // Parse all attached regions.
1395   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1396       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1397       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1398     return failure();
1399 
1400   return parser.resolveOperands(operandInfos, operandTypes,
1401                                 parser.getCurrentLocation(), result.operands);
1402 }
1403 
1404 OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) {
1405   assert(index && *index < 2 && "invalid region index");
1406   return getOperands();
1407 }
1408 
1409 void RegionIfOp::getSuccessorRegions(
1410     Optional<unsigned> index, ArrayRef<Attribute> operands,
1411     SmallVectorImpl<RegionSuccessor> &regions) {
1412   // We always branch to the join region.
1413   if (index.hasValue()) {
1414     if (index.getValue() < 2)
1415       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1416     else
1417       regions.push_back(RegionSuccessor(getResults()));
1418     return;
1419   }
1420 
1421   // The then and else regions are the entry regions of this op.
1422   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1423   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1424 }
1425 
1426 void RegionIfOp::getRegionInvocationBounds(
1427     ArrayRef<Attribute> operands,
1428     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1429   // Each region is invoked at most once.
1430   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1431 }
1432 
1433 //===----------------------------------------------------------------------===//
1434 // AnyCondOp
1435 //===----------------------------------------------------------------------===//
1436 
1437 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1438                                     ArrayRef<Attribute> operands,
1439                                     SmallVectorImpl<RegionSuccessor> &regions) {
1440   // The parent op branches into the only region, and the region branches back
1441   // to the parent op.
1442   if (!index)
1443     regions.emplace_back(&getRegion());
1444   else
1445     regions.emplace_back(getResults());
1446 }
1447 
1448 void AnyCondOp::getRegionInvocationBounds(
1449     ArrayRef<Attribute> operands,
1450     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1451   invocationBounds.emplace_back(1, 1);
1452 }
1453 
1454 //===----------------------------------------------------------------------===//
1455 // SingleNoTerminatorCustomAsmOp
1456 //===----------------------------------------------------------------------===//
1457 
1458 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1459                                                  OperationState &state) {
1460   Region *body = state.addRegion();
1461   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1462     return failure();
1463   return success();
1464 }
1465 
1466 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1467   printer.printRegion(
1468       getRegion(), /*printEntryBlockArgs=*/false,
1469       // This op has a single block without terminators. But explicitly mark
1470       // as not printing block terminators for testing.
1471       /*printBlockTerminators=*/false);
1472 }
1473 
1474 //===----------------------------------------------------------------------===//
1475 // TestVerifiersOp
1476 //===----------------------------------------------------------------------===//
1477 
1478 LogicalResult TestVerifiersOp::verify() {
1479   if (!getRegion().hasOneBlock())
1480     return emitOpError("`hasOneBlock` trait hasn't been verified");
1481 
1482   Operation *definingOp = getInput().getDefiningOp();
1483   if (definingOp && failed(mlir::verify(definingOp)))
1484     return emitOpError("operand hasn't been verified");
1485 
1486   emitRemark("success run of verifier");
1487 
1488   return success();
1489 }
1490 
1491 LogicalResult TestVerifiersOp::verifyRegions() {
1492   if (!getRegion().hasOneBlock())
1493     return emitOpError("`hasOneBlock` trait hasn't been verified");
1494 
1495   for (Block &block : getRegion())
1496     for (Operation &op : block)
1497       if (failed(mlir::verify(&op)))
1498         return emitOpError("nested op hasn't been verified");
1499 
1500   emitRemark("success run of region verifier");
1501 
1502   return success();
1503 }
1504 
1505 //===----------------------------------------------------------------------===//
1506 // Test InferIntRangeInterface
1507 //===----------------------------------------------------------------------===//
1508 
1509 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1510                                          SetIntRangeFn setResultRanges) {
1511   setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
1512 }
1513 
1514 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
1515                                           OperationState &result) {
1516   if (parser.parseOptionalAttrDict(result.attributes))
1517     return failure();
1518 
1519   // Parse the input argument
1520   OpAsmParser::Argument argInfo;
1521   argInfo.type = parser.getBuilder().getIndexType();
1522   if (failed(parser.parseArgument(argInfo)))
1523     return failure();
1524 
1525   // Parse the body region, and reuse the operand info as the argument info.
1526   Region *body = result.addRegion();
1527   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
1528 }
1529 
1530 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
1531   p.printOptionalAttrDict((*this)->getAttrs());
1532   p << ' ';
1533   p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
1534                         /*omitType=*/true);
1535   p << ' ';
1536   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
1537 }
1538 
1539 void TestWithBoundsRegionOp::inferResultRanges(
1540     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1541   Value arg = getRegion().getArgument(0);
1542   setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
1543 }
1544 
1545 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1546                                         SetIntRangeFn setResultRanges) {
1547   const ConstantIntRanges &range = argRanges[0];
1548   APInt one(range.umin().getBitWidth(), 1);
1549   setResultRanges(getResult(),
1550                   {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
1551                    range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
1552 }
1553 
1554 void TestReflectBoundsOp::inferResultRanges(
1555     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1556   const ConstantIntRanges &range = argRanges[0];
1557   MLIRContext *ctx = getContext();
1558   Builder b(ctx);
1559   setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
1560   setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
1561   setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
1562   setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
1563   setResultRanges(getResult(), range);
1564 }
1565 
1566 #include "TestOpEnums.cpp.inc"
1567 #include "TestOpInterfaces.cpp.inc"
1568 #include "TestTypeInterfaces.cpp.inc"
1569 
1570 #define GET_OP_CLASSES
1571 #include "TestOps.cpp.inc"
1572