1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/ExtensibleDialect.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/Verifier.h"
28 #include "mlir/Interfaces/InferIntRangeInterface.h"
29 #include "mlir/Reducer/ReductionPatternInterface.h"
30 #include "mlir/Transforms/FoldUtils.h"
31 #include "mlir/Transforms/InliningUtils.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/ADT/StringSwitch.h"
35 
36 // Include this before the using namespace lines below to
37 // test that we don't have namespace dependencies.
38 #include "TestOpsDialect.cpp.inc"
39 
40 using namespace mlir;
41 using namespace test;
42 
registerTestDialect(DialectRegistry & registry)43 void test::registerTestDialect(DialectRegistry &registry) {
44   registry.insert<TestDialect>();
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // External Elements Data
49 //===----------------------------------------------------------------------===//
50 
getData() const51 ArrayRef<uint64_t> TestExternalElementsData::getData() const {
52   ArrayRef<char> data = AsmResourceBlob::getData();
53   return ArrayRef<uint64_t>((const uint64_t *)data.data(),
54                             data.size() / sizeof(uint64_t));
55 }
56 
57 TestExternalElementsData
allocate(size_t numElements)58 TestExternalElementsData::allocate(size_t numElements) {
59   return TestExternalElementsData(
60       llvm::ArrayRef<uint64_t>(new uint64_t[numElements], numElements),
61       [](const uint64_t *data, size_t) { delete[] data; },
62       /*dataIsMutable=*/true);
63 }
64 
65 const TestExternalElementsData *
getData(StringRef name) const66 TestExternalElementsDataManager::getData(StringRef name) const {
67   auto it = dataMap.find(name);
68   return it != dataMap.end() ? &*it->second : nullptr;
69 }
70 
71 std::pair<TestExternalElementsDataManager::DataMap::iterator, bool>
insert(StringRef name)72 TestExternalElementsDataManager::insert(StringRef name) {
73   auto it = dataMap.try_emplace(name, nullptr);
74   if (it.second)
75     return it;
76 
77   llvm::SmallString<32> nameStorage(name);
78   nameStorage.push_back('_');
79   size_t nameCounter = 1;
80   do {
81     nameStorage += std::to_string(nameCounter++);
82     auto it = dataMap.try_emplace(nameStorage, nullptr);
83     if (it.second)
84       return it;
85     nameStorage.resize(name.size() + 1);
86   } while (true);
87 }
88 
setData(StringRef name,TestExternalElementsData && data)89 void TestExternalElementsDataManager::setData(StringRef name,
90                                               TestExternalElementsData &&data) {
91   auto it = dataMap.find(name);
92   assert(it != dataMap.end() && "data not registered");
93   it->second = std::make_unique<TestExternalElementsData>(std::move(data));
94 }
95 
96 //===----------------------------------------------------------------------===//
97 // TestDialect Interfaces
98 //===----------------------------------------------------------------------===//
99 
100 namespace {
101 
102 /// Testing the correctness of some traits.
103 static_assert(
104     llvm::is_detected<OpTrait::has_implicit_terminator_t,
105                       SingleBlockImplicitTerminatorOp>::value,
106     "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
107 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
108                   SingleBlockImplicitTerminatorOp>::value,
109               "hasSingleBlockImplicitTerminator does not match "
110               "SingleBlockImplicitTerminatorOp");
111 
112 // Test support for interacting with the AsmPrinter.
113 struct TestOpAsmInterface : public OpAsmDialectInterface {
114   using OpAsmDialectInterface::OpAsmDialectInterface;
115 
116   //===------------------------------------------------------------------===//
117   // Aliases
118   //===------------------------------------------------------------------===//
119 
getAlias__anonf77d94720211::TestOpAsmInterface120   AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
121     StringAttr strAttr = attr.dyn_cast<StringAttr>();
122     if (!strAttr)
123       return AliasResult::NoAlias;
124 
125     // Check the contents of the string attribute to see what the test alias
126     // should be named.
127     Optional<StringRef> aliasName =
128         StringSwitch<Optional<StringRef>>(strAttr.getValue())
129             .Case("alias_test:dot_in_name", StringRef("test.alias"))
130             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
131             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
132             .Case("alias_test:sanitize_conflict_a",
133                   StringRef("test_alias_conflict0"))
134             .Case("alias_test:sanitize_conflict_b",
135                   StringRef("test_alias_conflict0_"))
136             .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
137             .Default(llvm::None);
138     if (!aliasName)
139       return AliasResult::NoAlias;
140 
141     os << *aliasName;
142     return AliasResult::FinalAlias;
143   }
144 
getAlias__anonf77d94720211::TestOpAsmInterface145   AliasResult getAlias(Type type, raw_ostream &os) const final {
146     if (auto tupleType = type.dyn_cast<TupleType>()) {
147       if (tupleType.size() > 0 &&
148           llvm::all_of(tupleType.getTypes(), [](Type elemType) {
149             return elemType.isa<SimpleAType>();
150           })) {
151         os << "test_tuple";
152         return AliasResult::FinalAlias;
153       }
154     }
155     if (auto intType = type.dyn_cast<TestIntegerType>()) {
156       if (intType.getSignedness() ==
157               TestIntegerType::SignednessSemantics::Unsigned &&
158           intType.getWidth() == 8) {
159         os << "test_ui8";
160         return AliasResult::FinalAlias;
161       }
162     }
163     if (auto recType = type.dyn_cast<TestRecursiveType>()) {
164       if (recType.getName() == "type_to_alias") {
165         // We only make alias for a specific recursive type.
166         os << "testrec";
167         return AliasResult::FinalAlias;
168       }
169     }
170     return AliasResult::NoAlias;
171   }
172 
173   //===------------------------------------------------------------------===//
174   // Resources
175   //===------------------------------------------------------------------===//
176 
177   std::string
getResourceKey__anonf77d94720211::TestOpAsmInterface178   getResourceKey(const AsmDialectResourceHandle &handle) const override {
179     return cast<TestExternalElementsDataHandle>(handle).getKey().str();
180   }
181 
182   FailureOr<AsmDialectResourceHandle>
declareResource__anonf77d94720211::TestOpAsmInterface183   declareResource(StringRef key) const final {
184     TestDialect *dialect = cast<TestDialect>(getDialect());
185     TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
186 
187     // Resolve the reference by inserting a new entry into the manager.
188     auto it = mgr.insert(key).first;
189     return TestExternalElementsDataHandle(&*it, dialect);
190   }
191 
parseResource__anonf77d94720211::TestOpAsmInterface192   LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
193     TestDialect *dialect = cast<TestDialect>(getDialect());
194     TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
195 
196     // The resource entries are external constant data.
197     auto blobAllocFn = [](unsigned size, unsigned align) {
198       assert(align == alignof(uint64_t) && "unexpected data alignment");
199       return TestExternalElementsData::allocate(size / sizeof(uint64_t));
200     };
201     FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(blobAllocFn);
202     if (failed(blob))
203       return failure();
204 
205     mgr.setData(entry.getKey(), std::move(*blob));
206     return success();
207   }
208 
209   void
buildResources__anonf77d94720211::TestOpAsmInterface210   buildResources(Operation *op,
211                  const SetVector<AsmDialectResourceHandle> &referencedResources,
212                  AsmResourceBuilder &provider) const final {
213     for (const AsmDialectResourceHandle &handle : referencedResources) {
214       const auto &testHandle = cast<TestExternalElementsDataHandle>(handle);
215       provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData());
216     }
217   }
218 };
219 
220 struct TestDialectFoldInterface : public DialectFoldInterface {
221   using DialectFoldInterface::DialectFoldInterface;
222 
223   /// Registered hook to check if the given region, which is attached to an
224   /// operation that is *not* isolated from above, should be used when
225   /// materializing constants.
shouldMaterializeInto__anonf77d94720211::TestDialectFoldInterface226   bool shouldMaterializeInto(Region *region) const final {
227     // If this is a one region operation, then insert into it.
228     return isa<OneRegionOp>(region->getParentOp());
229   }
230 };
231 
232 /// This class defines the interface for handling inlining with standard
233 /// operations.
234 struct TestInlinerInterface : public DialectInlinerInterface {
235   using DialectInlinerInterface::DialectInlinerInterface;
236 
237   //===--------------------------------------------------------------------===//
238   // Analysis Hooks
239   //===--------------------------------------------------------------------===//
240 
isLegalToInline__anonf77d94720211::TestInlinerInterface241   bool isLegalToInline(Operation *call, Operation *callable,
242                        bool wouldBeCloned) const final {
243     // Don't allow inlining calls that are marked `noinline`.
244     return !call->hasAttr("noinline");
245   }
isLegalToInline__anonf77d94720211::TestInlinerInterface246   bool isLegalToInline(Region *, Region *, bool,
247                        BlockAndValueMapping &) const final {
248     // Inlining into test dialect regions is legal.
249     return true;
250   }
isLegalToInline__anonf77d94720211::TestInlinerInterface251   bool isLegalToInline(Operation *, Region *, bool,
252                        BlockAndValueMapping &) const final {
253     return true;
254   }
255 
shouldAnalyzeRecursively__anonf77d94720211::TestInlinerInterface256   bool shouldAnalyzeRecursively(Operation *op) const final {
257     // Analyze recursively if this is not a functional region operation, it
258     // froms a separate functional scope.
259     return !isa<FunctionalRegionOp>(op);
260   }
261 
262   //===--------------------------------------------------------------------===//
263   // Transformation Hooks
264   //===--------------------------------------------------------------------===//
265 
266   /// Handle the given inlined terminator by replacing it with a new operation
267   /// as necessary.
handleTerminator__anonf77d94720211::TestInlinerInterface268   void handleTerminator(Operation *op,
269                         ArrayRef<Value> valuesToRepl) const final {
270     // Only handle "test.return" here.
271     auto returnOp = dyn_cast<TestReturnOp>(op);
272     if (!returnOp)
273       return;
274 
275     // Replace the values directly with the return operands.
276     assert(returnOp.getNumOperands() == valuesToRepl.size());
277     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
278       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
279   }
280 
281   /// Attempt to materialize a conversion for a type mismatch between a call
282   /// from this dialect, and a callable region. This method should generate an
283   /// operation that takes 'input' as the only operand, and produces a single
284   /// result of 'resultType'. If a conversion can not be generated, nullptr
285   /// should be returned.
materializeCallConversion__anonf77d94720211::TestInlinerInterface286   Operation *materializeCallConversion(OpBuilder &builder, Value input,
287                                        Type resultType,
288                                        Location conversionLoc) const final {
289     // Only allow conversion for i16/i32 types.
290     if (!(resultType.isSignlessInteger(16) ||
291           resultType.isSignlessInteger(32)) ||
292         !(input.getType().isSignlessInteger(16) ||
293           input.getType().isSignlessInteger(32)))
294       return nullptr;
295     return builder.create<TestCastOp>(conversionLoc, resultType, input);
296   }
297 
processInlinedCallBlocks__anonf77d94720211::TestInlinerInterface298   void processInlinedCallBlocks(
299       Operation *call,
300       iterator_range<Region::iterator> inlinedBlocks) const final {
301     if (!isa<ConversionCallOp>(call))
302       return;
303 
304     // Set attributed on all ops in the inlined blocks.
305     for (Block &block : inlinedBlocks) {
306       block.walk([&](Operation *op) {
307         op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
308       });
309     }
310   }
311 };
312 
313 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
314 public:
TestReductionPatternInterface__anonf77d94720211::TestReductionPatternInterface315   TestReductionPatternInterface(Dialect *dialect)
316       : DialectReductionPatternInterface(dialect) {}
317 
populateReductionPatterns__anonf77d94720211::TestReductionPatternInterface318   void populateReductionPatterns(RewritePatternSet &patterns) const final {
319     populateTestReductionPatterns(patterns);
320   }
321 };
322 
323 } // namespace
324 
325 //===----------------------------------------------------------------------===//
326 // Dynamic operations
327 //===----------------------------------------------------------------------===//
328 
getDynamicGenericOp(TestDialect * dialect)329 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
330   return DynamicOpDefinition::get(
331       "dynamic_generic", dialect, [](Operation *op) { return success(); },
332       [](Operation *op) { return success(); });
333 }
334 
335 std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect * dialect)336 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
337   return DynamicOpDefinition::get(
338       "dynamic_one_operand_two_results", dialect,
339       [](Operation *op) {
340         if (op->getNumOperands() != 1) {
341           op->emitOpError()
342               << "expected 1 operand, but had " << op->getNumOperands();
343           return failure();
344         }
345         if (op->getNumResults() != 2) {
346           op->emitOpError()
347               << "expected 2 results, but had " << op->getNumResults();
348           return failure();
349         }
350         return success();
351       },
352       [](Operation *op) { return success(); });
353 }
354 
355 std::unique_ptr<DynamicOpDefinition>
getDynamicCustomParserPrinterOp(TestDialect * dialect)356 getDynamicCustomParserPrinterOp(TestDialect *dialect) {
357   auto verifier = [](Operation *op) {
358     if (op->getNumOperands() == 0 && op->getNumResults() == 0)
359       return success();
360     op->emitError() << "operation should have no operands and no results";
361     return failure();
362   };
363   auto regionVerifier = [](Operation *op) { return success(); };
364 
365   auto parser = [](OpAsmParser &parser, OperationState &state) {
366     return parser.parseKeyword("custom_keyword");
367   };
368 
369   auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
370     printer << op->getName() << " custom_keyword";
371   };
372 
373   return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
374                                   verifier, regionVerifier, parser, printer);
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // TestDialect
379 //===----------------------------------------------------------------------===//
380 
// Forward declaration of the file-local helper that appends the test effects
// of `op` to `effects`. It is declared here so that the fallback interface
// model defined below can call it.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
384 
385 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
386 struct TestOpEffectInterfaceFallback
387     : public TestEffectOpInterface::FallbackModel<
388           TestOpEffectInterfaceFallback> {
classofTestOpEffectInterfaceFallback389   static bool classof(Operation *op) {
390     bool isSupportedOp =
391         op->getName().getStringRef() == "test.unregistered_side_effect_op";
392     assert(isSupportedOp && "Unexpected dispatch");
393     return isSupportedOp;
394   }
395 
396   void
getEffectsTestOpEffectInterfaceFallback397   getEffects(Operation *op,
398              SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
399                  &effects) const {
400     testSideEffectOpGetEffect(op, effects);
401   }
402 };
403 
/// Sets up the test dialect: registers its attributes, types, generated and
/// dynamic operations, and dialect interfaces, allows unknown operations, and
/// creates the fallback effect interface model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The list of generated operations is pulled in from the generated include,
  // selected by GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  // Operations defined at runtime rather than through generated classes.
  registerDynamicOp(getDynamicGenericOp(this));
  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
  registerDynamicOp(getDynamicCustomParserPrinterOp(this));

  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface, TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op. The object is owned through a raw pointer and is released
  // in the destructor.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
~TestDialect()423 TestDialect::~TestDialect() {
424   delete static_cast<TestOpEffectInterfaceFallback *>(
425       fallbackEffectOpInterfaces);
426 }
427 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)428 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
429                                             Type type, Location loc) {
430   return builder.create<TestOpConstant>(loc, type, value);
431 }
432 
inferReturnTypes(::mlir::MLIRContext * context,::llvm::Optional<::mlir::Location> location,::mlir::ValueRange operands,::mlir::DictionaryAttr attributes,::mlir::RegionRange regions,::llvm::SmallVectorImpl<::mlir::Type> & inferredReturnTypes)433 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
434     ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
435     ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
436     ::mlir::RegionRange regions,
437     ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
438   inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
439   return ::mlir::success();
440 }
441 
getRegisteredInterfaceForOp(TypeID typeID,OperationName opName)442 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
443                                                OperationName opName) {
444   if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
445       typeID == TypeID::get<TestEffectOpInterface>())
446     return fallbackEffectOpInterfaces;
447   return nullptr;
448 }
449 
verifyOperationAttribute(Operation * op,NamedAttribute namedAttr)450 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
451                                                     NamedAttribute namedAttr) {
452   if (namedAttr.getName() == "test.invalid_attr")
453     return op->emitError() << "invalid to use 'test.invalid_attr'";
454   return success();
455 }
456 
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute namedAttr)457 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
458                                                     unsigned regionIndex,
459                                                     unsigned argIndex,
460                                                     NamedAttribute namedAttr) {
461   if (namedAttr.getName() == "test.invalid_attr")
462     return op->emitError() << "invalid to use 'test.invalid_attr'";
463   return success();
464 }
465 
466 LogicalResult
verifyRegionResultAttribute(Operation * op,unsigned regionIndex,unsigned resultIndex,NamedAttribute namedAttr)467 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
468                                          unsigned resultIndex,
469                                          NamedAttribute namedAttr) {
470   if (namedAttr.getName() == "test.invalid_attr")
471     return op->emitError() << "invalid to use 'test.invalid_attr'";
472   return success();
473 }
474 
/// Returns the parse hook for the operation names that the dialect parses
/// itself, or None for every other name:
///  * "test.dialect_custom_printer"          expects the keyword
///    `custom_format`;
///  * "test.dialect_custom_format_fallback"  expects the keyword
///    `custom_format_fallback`;
///  * "test.dialect_custom_printer.with.dot" parses nothing and succeeds.
// NOTE(review): each hook wraps a lambda temporary; whether ParseOpHook owns
// its callable is not visible from here — confirm before restructuring.
Optional<Dialect::ParseOpHook>
TestDialect::getParseOperationHook(StringRef opName) const {
  if (opName == "test.dialect_custom_printer") {
    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
      return parser.parseKeyword("custom_format");
    }};
  }
  if (opName == "test.dialect_custom_format_fallback") {
    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
      return parser.parseKeyword("custom_format_fallback");
    }};
  }
  if (opName == "test.dialect_custom_printer.with.dot") {
    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
      return ParseResult::success();
    }};
  }
  return None;
}
494 
495 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
getOperationPrinter(Operation * op) const496 TestDialect::getOperationPrinter(Operation *op) const {
497   StringRef opName = op->getName().getStringRef();
498   if (opName == "test.dialect_custom_printer") {
499     return [](Operation *op, OpAsmPrinter &printer) {
500       printer.getStream() << " custom_format";
501     };
502   }
503   if (opName == "test.dialect_custom_format_fallback") {
504     return [](Operation *op, OpAsmPrinter &printer) {
505       printer.getStream() << " custom_format_fallback";
506     };
507   }
508   return {};
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // TestBranchOp
513 //===----------------------------------------------------------------------===//
514 
getSuccessorOperands(unsigned index)515 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
516   assert(index == 0 && "invalid successor index");
517   return SuccessorOperands(getTargetOperandsMutable());
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // TestProducingBranchOp
522 //===----------------------------------------------------------------------===//
523 
getSuccessorOperands(unsigned index)524 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
525   assert(index <= 1 && "invalid successor index");
526   if (index == 1)
527     return SuccessorOperands(getFirstOperandsMutable());
528   return SuccessorOperands(getSecondOperandsMutable());
529 }
530 
531 //===----------------------------------------------------------------------===//
// TestInternalBranchOp
533 //===----------------------------------------------------------------------===//
534 
getSuccessorOperands(unsigned index)535 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
536   assert(index <= 1 && "invalid successor index");
537   if (index == 0)
538     return SuccessorOperands(0, getSuccessOperandsMutable());
539   return SuccessorOperands(1, getErrorOperandsMutable());
540 }
541 
542 //===----------------------------------------------------------------------===//
543 // TestDialectCanonicalizerOp
544 //===----------------------------------------------------------------------===//
545 
546 static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,PatternRewriter & rewriter)547 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
548                                PatternRewriter &rewriter) {
549   rewriter.replaceOpWithNewOp<arith::ConstantOp>(
550       op, rewriter.getI32IntegerAttr(42));
551   return success();
552 }
553 
getCanonicalizationPatterns(RewritePatternSet & results) const554 void TestDialect::getCanonicalizationPatterns(
555     RewritePatternSet &results) const {
556   results.add(&dialectCanonicalizationPattern);
557 }
558 
559 //===----------------------------------------------------------------------===//
560 // TestCallOp
561 //===----------------------------------------------------------------------===//
562 
verifySymbolUses(SymbolTableCollection & symbolTable)563 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
564   // Check that the callee attribute was specified.
565   auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
566   if (!fnAttr)
567     return emitOpError("requires a 'callee' symbol reference attribute");
568   if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
569     return emitOpError() << "'" << fnAttr.getValue()
570                          << "' does not reference a valid function";
571   return success();
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // TestFoldToCallOp
576 //===----------------------------------------------------------------------===//
577 
578 namespace {
579 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
580   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
581 
matchAndRewrite__anonf77d94721311::FoldToCallOpPattern582   LogicalResult matchAndRewrite(FoldToCallOp op,
583                                 PatternRewriter &rewriter) const override {
584     rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
585                                               op.getCalleeAttr(), ValueRange());
586     return success();
587   }
588 };
589 } // namespace
590 
/// Registers `FoldToCallOpPattern` as the canonicalization pattern of
/// `FoldToCallOp`.
void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldToCallOpPattern>(context);
}
595 
596 //===----------------------------------------------------------------------===//
597 // Test Format* operations
598 //===----------------------------------------------------------------------===//
599 
600 //===----------------------------------------------------------------------===//
601 // Parsing
602 
parseCustomOptionalOperand(OpAsmParser & parser,Optional<OpAsmParser::UnresolvedOperand> & optOperand)603 static ParseResult parseCustomOptionalOperand(
604     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
605   if (succeeded(parser.parseOptionalLParen())) {
606     optOperand.emplace();
607     if (parser.parseOperand(*optOperand) || parser.parseRParen())
608       return failure();
609   }
610   return success();
611 }
612 
parseCustomDirectiveOperands(OpAsmParser & parser,OpAsmParser::UnresolvedOperand & operand,Optional<OpAsmParser::UnresolvedOperand> & optOperand,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & varOperands)613 static ParseResult parseCustomDirectiveOperands(
614     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
615     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
616     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
617   if (parser.parseOperand(operand))
618     return failure();
619   if (succeeded(parser.parseOptionalComma())) {
620     optOperand.emplace();
621     if (parser.parseOperand(*optOperand))
622       return failure();
623   }
624   if (parser.parseArrow() || parser.parseLParen() ||
625       parser.parseOperandList(varOperands) || parser.parseRParen())
626     return failure();
627   return success();
628 }
629 static ParseResult
parseCustomDirectiveResults(OpAsmParser & parser,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)630 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
631                             Type &optOperandType,
632                             SmallVectorImpl<Type> &varOperandTypes) {
633   if (parser.parseColon())
634     return failure();
635 
636   if (parser.parseType(operandType))
637     return failure();
638   if (succeeded(parser.parseOptionalComma())) {
639     if (parser.parseType(optOperandType))
640       return failure();
641   }
642   if (parser.parseArrow() || parser.parseLParen() ||
643       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
644     return failure();
645   return success();
646 }
647 static ParseResult
parseCustomDirectiveWithTypeRefs(OpAsmParser & parser,Type operandType,Type optOperandType,const SmallVectorImpl<Type> & varOperandTypes)648 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
649                                  Type optOperandType,
650                                  const SmallVectorImpl<Type> &varOperandTypes) {
651   if (parser.parseKeyword("type_refs_capture"))
652     return failure();
653 
654   Type operandType2, optOperandType2;
655   SmallVector<Type, 1> varOperandTypes2;
656   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
657                                   varOperandTypes2))
658     return failure();
659 
660   if (operandType != operandType2 || optOperandType != optOperandType2 ||
661       varOperandTypes != varOperandTypes2)
662     return failure();
663 
664   return success();
665 }
parseCustomDirectiveOperandsAndTypes(OpAsmParser & parser,OpAsmParser::UnresolvedOperand & operand,Optional<OpAsmParser::UnresolvedOperand> & optOperand,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & varOperands,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)666 static ParseResult parseCustomDirectiveOperandsAndTypes(
667     OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
668     Optional<OpAsmParser::UnresolvedOperand> &optOperand,
669     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
670     Type &operandType, Type &optOperandType,
671     SmallVectorImpl<Type> &varOperandTypes) {
672   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
673       parseCustomDirectiveResults(parser, operandType, optOperandType,
674                                   varOperandTypes))
675     return failure();
676   return success();
677 }
parseCustomDirectiveRegions(OpAsmParser & parser,Region & region,SmallVectorImpl<std::unique_ptr<Region>> & varRegions)678 static ParseResult parseCustomDirectiveRegions(
679     OpAsmParser &parser, Region &region,
680     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
681   if (parser.parseRegion(region))
682     return failure();
683   if (failed(parser.parseOptionalComma()))
684     return success();
685   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
686   if (parser.parseRegion(*varRegion))
687     return failure();
688   varRegions.emplace_back(std::move(varRegion));
689   return success();
690 }
691 static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser & parser,Block * & successor,SmallVectorImpl<Block * > & varSuccessors)692 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
693                                SmallVectorImpl<Block *> &varSuccessors) {
694   if (parser.parseSuccessor(successor))
695     return failure();
696   if (failed(parser.parseOptionalComma()))
697     return success();
698   Block *varSuccessor;
699   if (parser.parseSuccessor(varSuccessor))
700     return failure();
701   varSuccessors.append(2, varSuccessor);
702   return success();
703 }
parseCustomDirectiveAttributes(OpAsmParser & parser,IntegerAttr & attr,IntegerAttr & optAttr)704 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
705                                                   IntegerAttr &attr,
706                                                   IntegerAttr &optAttr) {
707   if (parser.parseAttribute(attr))
708     return failure();
709   if (succeeded(parser.parseOptionalComma())) {
710     if (parser.parseAttribute(optAttr))
711       return failure();
712   }
713   return success();
714 }
715 
/// Custom directive parser for an attribute dictionary: forwards to
/// `parseOptionalAttrDict`, storing the parsed attributes in `attrs`.
static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                                NamedAttrList &attrs) {
  return parser.parseOptionalAttrDict(attrs);
}
parseCustomDirectiveOptionalOperandRef(OpAsmParser & parser,Optional<OpAsmParser::UnresolvedOperand> & optOperand)720 static ParseResult parseCustomDirectiveOptionalOperandRef(
721     OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
722   int64_t operandCount = 0;
723   if (parser.parseInteger(operandCount))
724     return failure();
725   bool expectedOptionalOperand = operandCount == 0;
726   return success(expectedOptionalOperand != optOperand.has_value());
727 }
728 
//===----------------------------------------------------------------------===//
// Printing
//===----------------------------------------------------------------------===//
731 
printCustomOptionalOperand(OpAsmPrinter & printer,Operation *,Value optOperand)732 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
733                                        Value optOperand) {
734   if (optOperand)
735     printer << "(" << optOperand << ") ";
736 }
737 
printCustomDirectiveOperands(OpAsmPrinter & printer,Operation *,Value operand,Value optOperand,OperandRange varOperands)738 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
739                                          Value operand, Value optOperand,
740                                          OperandRange varOperands) {
741   printer << operand;
742   if (optOperand)
743     printer << ", " << optOperand;
744   printer << " -> (" << varOperands << ")";
745 }
printCustomDirectiveResults(OpAsmPrinter & printer,Operation *,Type operandType,Type optOperandType,TypeRange varOperandTypes)746 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
747                                         Type operandType, Type optOperandType,
748                                         TypeRange varOperandTypes) {
749   printer << " : " << operandType;
750   if (optOperandType)
751     printer << ", " << optOperandType;
752   printer << " -> (" << varOperandTypes << ")";
753 }
/// Prints the `type_refs_capture` keyword followed by the same type list that
/// `printCustomDirectiveResults` emits for the given types.
static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
                                             Operation *op, Type operandType,
                                             Type optOperandType,
                                             TypeRange varOperandTypes) {
  printer << " type_refs_capture ";
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
/// Prints the operands followed by their types by chaining
/// `printCustomDirectiveOperands` and `printCustomDirectiveResults`.
static void printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
printCustomDirectiveRegions(OpAsmPrinter & printer,Operation *,Region & region,MutableArrayRef<Region> varRegions)770 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
771                                         Region &region,
772                                         MutableArrayRef<Region> varRegions) {
773   printer.printRegion(region);
774   if (!varRegions.empty()) {
775     printer << ", ";
776     for (Region &region : varRegions)
777       printer.printRegion(region);
778   }
779 }
printCustomDirectiveSuccessors(OpAsmPrinter & printer,Operation *,Block * successor,SuccessorRange varSuccessors)780 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
781                                            Block *successor,
782                                            SuccessorRange varSuccessors) {
783   printer << successor;
784   if (!varSuccessors.empty())
785     printer << ", " << varSuccessors.front();
786 }
printCustomDirectiveAttributes(OpAsmPrinter & printer,Operation *,Attribute attribute,Attribute optAttribute)787 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
788                                            Attribute attribute,
789                                            Attribute optAttribute) {
790   printer << attribute;
791   if (optAttribute)
792     printer << ", " << optAttribute;
793 }
794 
/// Custom directive printer for an attribute dictionary: forwards the entries
/// of `attrs` to `printOptionalAttrDict`. The operation itself is unused.
static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
                                         DictionaryAttr attrs) {
  printer.printOptionalAttrDict(attrs.getValue());
}
799 
printCustomDirectiveOptionalOperandRef(OpAsmPrinter & printer,Operation * op,Value optOperand)800 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
801                                                    Operation *op,
802                                                    Value optOperand) {
803   printer << (optOperand ? "1" : "0");
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // Test IsolatedRegionOp - parse passthrough region arguments.
808 //===----------------------------------------------------------------------===//
809 
parse(OpAsmParser & parser,OperationState & result)810 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
811                                     OperationState &result) {
812   // Parse the input operand.
813   OpAsmParser::Argument argInfo;
814   argInfo.type = parser.getBuilder().getIndexType();
815   if (parser.parseOperand(argInfo.ssaName) ||
816       parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
817     return failure();
818 
819   // Parse the body region, and reuse the operand info as the argument info.
820   Region *body = result.addRegion();
821   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
822 }
823 
print(OpAsmPrinter & p)824 void IsolatedRegionOp::print(OpAsmPrinter &p) {
825   p << "test.isolated_region ";
826   p.printOperand(getOperand());
827   p.shadowRegionArgs(getRegion(), getOperand());
828   p << ' ';
829   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
830 }
831 
832 //===----------------------------------------------------------------------===//
833 // Test SSACFGRegionOp
834 //===----------------------------------------------------------------------===//
835 
/// Reports every region of this op as an SSACFG region, regardless of `index`.
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
  return RegionKind::SSACFG;
}
839 
840 //===----------------------------------------------------------------------===//
841 // Test GraphRegionOp
842 //===----------------------------------------------------------------------===//
843 
/// Reports every region of this op as a graph region, regardless of `index`.
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
  return RegionKind::Graph;
}
847 
848 //===----------------------------------------------------------------------===//
849 // Test AffineScopeOp
850 //===----------------------------------------------------------------------===//
851 
parse(OpAsmParser & parser,OperationState & result)852 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
853   // Parse the body region, and reuse the operand info as the argument info.
854   Region *body = result.addRegion();
855   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
856 }
857 
print(OpAsmPrinter & p)858 void AffineScopeOp::print(OpAsmPrinter &p) {
859   p << "test.affine_scope ";
860   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
861 }
862 
863 //===----------------------------------------------------------------------===//
864 // Test parser.
865 //===----------------------------------------------------------------------===//
866 
parse(OpAsmParser & parser,OperationState & result)867 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
868                                          OperationState &result) {
869   if (parser.parseOptionalColon())
870     return success();
871   uint64_t numResults;
872   if (parser.parseInteger(numResults))
873     return failure();
874 
875   IndexType type = parser.getBuilder().getIndexType();
876   for (unsigned i = 0; i < numResults; ++i)
877     result.addTypes(type);
878   return success();
879 }
880 
print(OpAsmPrinter & p)881 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
882   if (unsigned numResults = getNumResults())
883     p << " : " << numResults;
884 }
885 
parse(OpAsmParser & parser,OperationState & result)886 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
887                                          OperationState &result) {
888   StringRef keyword;
889   if (parser.parseKeyword(&keyword))
890     return failure();
891   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
892   return success();
893 }
894 
/// Prints a space followed by the value returned by `getKeyword()`.
void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
896 
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
899 
parse(OpAsmParser & parser,OperationState & result)900 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
901                                     OperationState &result) {
902   if (parser.parseKeyword("wraps"))
903     return failure();
904 
905   // Parse the wrapped op in a region
906   Region &body = *result.addRegion();
907   body.push_back(new Block);
908   Block &block = body.back();
909   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
910   if (!wrappedOp)
911     return failure();
912 
913   // Create a return terminator in the inner region, pass as operand to the
914   // terminator the returned values from the wrapped operation.
915   SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
916   OpBuilder builder(parser.getContext());
917   builder.setInsertionPointToEnd(&block);
918   builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
919 
920   // Get the results type for the wrapping op from the terminator operands.
921   Operation &returnOp = body.back().back();
922   result.types.append(returnOp.operand_type_begin(),
923                       returnOp.operand_type_end());
924 
925   // Use the location of the wrapped op for the "test.wrapping_region" op.
926   result.location = wrappedOp->getLoc();
927 
928   return success();
929 }
930 
/// Prints ` wraps ` followed by the generic form of the first operation in the
/// first block of the region. The terminator that follows it is not printed.
void WrappingRegionOp::print(OpAsmPrinter &p) {
  p << " wraps ";
  p.printGenericOp(&getRegion().front().front());
}
935 
936 //===----------------------------------------------------------------------===//
937 // Test PrettyPrintedRegionOp -  exercising the following parser APIs
938 //   parseGenericOperationAfterOpName
939 //   parseCustomOperationName
940 //===----------------------------------------------------------------------===//
941 
parse(OpAsmParser & parser,OperationState & result)942 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
943                                          OperationState &result) {
944 
945   SMLoc loc = parser.getCurrentLocation();
946   Location currLocation = parser.getEncodedSourceLoc(loc);
947 
948   // Parse the operands.
949   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
950   if (parser.parseOperandList(operands))
951     return failure();
952 
953   // Check if we are parsing the pretty-printed version
954   //  test.pretty_printed_region start <inner-op> end : <functional-type>
955   // Else fallback to parsing the "non pretty-printed" version.
956   if (!succeeded(parser.parseOptionalKeyword("start")))
957     return parser.parseGenericOperationAfterOpName(
958         result, llvm::makeArrayRef(operands));
959 
960   FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
961   if (failed(parseOpNameInfo))
962     return failure();
963 
964   StringAttr innerOpName = parseOpNameInfo->getIdentifier();
965 
966   FunctionType opFntype;
967   Optional<Location> explicitLoc;
968   if (parser.parseKeyword("end") || parser.parseColon() ||
969       parser.parseType(opFntype) ||
970       parser.parseOptionalLocationSpecifier(explicitLoc))
971     return failure();
972 
973   // If location of the op is explicitly provided, then use it; Else use
974   // the parser's current location.
975   Location opLoc = explicitLoc.value_or(currLocation);
976 
977   // Derive the SSA-values for op's operands.
978   if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
979                              result.operands))
980     return failure();
981 
982   // Add a region for op.
983   Region &region = *result.addRegion();
984 
985   // Create a basic-block inside op's region.
986   Block &block = region.emplaceBlock();
987 
988   // Create and insert an "inner-op" operation in the block.
989   // Just for testing purposes, we can assume that inner op is a binary op with
990   // result and operand types all same as the test-op's first operand.
991   Type innerOpType = opFntype.getInput(0);
992   Value lhs = block.addArgument(innerOpType, opLoc);
993   Value rhs = block.addArgument(innerOpType, opLoc);
994 
995   OpBuilder builder(parser.getBuilder().getContext());
996   builder.setInsertionPointToStart(&block);
997 
998   Operation *innerOp =
999       builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
1000 
1001   // Insert a return statement in the block returning the inner-op's result.
1002   builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
1003 
1004   // Populate the op operation-state with result-type and location.
1005   result.addTypes(opFntype.getResults());
1006   result.location = innerOp->getLoc();
1007 
1008   return success();
1009 }
1010 
print(OpAsmPrinter & p)1011 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
1012   p << ' ';
1013   p.printOperands(getOperands());
1014 
1015   Operation &innerOp = getRegion().front().front();
1016   // Assuming that region has a single non-terminator inner-op, if the inner-op
1017   // meets some criteria (which in this case is a simple one  based on the name
1018   // of inner-op), then we can print the entire region in a succinct way.
1019   // Here we assume that the prototype of "special.op" can be trivially derived
1020   // while parsing it back.
1021   if (innerOp.getName().getStringRef().equals("special.op")) {
1022     p << " start special.op end";
1023   } else {
1024     p << " (";
1025     p.printRegion(getRegion());
1026     p << ")";
1027   }
1028 
1029   p << " : ";
1030   p.printFunctionalType(*this);
1031 }
1032 
1033 //===----------------------------------------------------------------------===//
1034 // Test PolyForOp - parse list of region arguments.
1035 //===----------------------------------------------------------------------===//
1036 
parse(OpAsmParser & parser,OperationState & result)1037 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
1038   SmallVector<OpAsmParser::Argument, 4> ivsInfo;
1039   // Parse list of region arguments without a delimiter.
1040   if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
1041     return failure();
1042 
1043   // Parse the body region.
1044   Region *body = result.addRegion();
1045   for (auto &iv : ivsInfo)
1046     iv.type = parser.getBuilder().getIndexType();
1047   return parser.parseRegion(*body, ivsInfo);
1048 }
1049 
/// Always prints the op in its generic form.
void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
1051 
getAsmBlockArgumentNames(Region & region,OpAsmSetValueNameFn setNameFn)1052 void PolyForOp::getAsmBlockArgumentNames(Region &region,
1053                                          OpAsmSetValueNameFn setNameFn) {
1054   auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
1055   if (!arrayAttr)
1056     return;
1057   auto args = getRegion().front().getArguments();
1058   auto e = std::min(arrayAttr.size(), args.size());
1059   for (unsigned i = 0; i < e; ++i) {
1060     if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
1061       setNameFn(args[i], strAttr.getValue());
1062   }
1063 }
1064 
1065 //===----------------------------------------------------------------------===//
1066 // Test removing op with inner ops.
1067 //===----------------------------------------------------------------------===//
1068 
namespace {
/// Rewrite pattern that unconditionally erases any `TestOpWithRegionPattern`
/// op it is applied to.
struct TestRemoveOpWithInnerOps
    : public OpRewritePattern<TestOpWithRegionPattern> {
  using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;

  /// Sets the debug name of this pattern.
  void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }

  /// Erases `op` and always reports success.
  LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
                                PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
1083 
/// Registers `TestRemoveOpWithInnerOps` as the canonicalization pattern of
/// this op.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<TestRemoveOpWithInnerOps>(context);
}
1088 
/// Folds the op to its operand; the constant `operands` are not consulted.
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
  return getOperand();
}
1092 
/// Folds the op to the value returned by `getValue()`; the constant
/// `operands` are not consulted.
OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
  return getValue();
}
1096 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1097 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
1098     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
1099   for (Value input : this->getOperands()) {
1100     results.push_back(input);
1101   }
1102   return success();
1103 }
1104 
fold(ArrayRef<Attribute> operands)1105 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
1106   assert(operands.size() == 1);
1107   if (operands.front()) {
1108     (*this)->setAttr("attr", operands.front());
1109     return getResult();
1110   }
1111   return {};
1112 }
1113 
/// Folds the op to its operand; the constant `operands` are not consulted.
OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
  return getOperand();
}
1117 
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1118 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
1119     MLIRContext *, Optional<Location> location, ValueRange operands,
1120     DictionaryAttr attributes, RegionRange regions,
1121     SmallVectorImpl<Type> &inferredReturnTypes) {
1122   if (operands[0].getType() != operands[1].getType()) {
1123     return emitOptionalError(location, "operand type mismatch ",
1124                              operands[0].getType(), " vs ",
1125                              operands[1].getType());
1126   }
1127   inferredReturnTypes.assign({operands[0].getType()});
1128   return success();
1129 }
1130 
// TODO: We should be able to only define either inferReturnType or
// refineReturnType, currently only refineReturnType can be omitted.
/// Infers the return types by clearing `returnTypes` and delegating to
/// `refineReturnTypes` with the same arguments.
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &returnTypes) {
  // Start from an empty list so that nothing constrains the refinement.
  returnTypes.clear();
  return OpWithRefineTypeInterfaceOp::refineReturnTypes(
      context, location, operands, attributes, regions, returnTypes);
}
1141 
refineReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & returnTypes)1142 LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
1143     MLIRContext *, Optional<Location> location, ValueRange operands,
1144     DictionaryAttr attributes, RegionRange regions,
1145     SmallVectorImpl<Type> &returnTypes) {
1146   if (operands[0].getType() != operands[1].getType()) {
1147     return emitOptionalError(location, "operand type mismatch ",
1148                              operands[0].getType(), " vs ",
1149                              operands[1].getType());
1150   }
1151   // TODO: Add helper to make this more concise to write.
1152   if (returnTypes.empty())
1153     returnTypes.resize(1, nullptr);
1154   if (returnTypes[0] && returnTypes[0] != operands[0].getType())
1155     return emitOptionalError(location,
1156                              "required first operand and result to match");
1157   returnTypes[0] = operands[0].getType();
1158   return success();
1159 }
1160 
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1161 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
1162     MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
1163     DictionaryAttr attributes, RegionRange regions,
1164     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1165   // Create return type consisting of the last element of the first operand.
1166   auto operandType = operands.front().getType();
1167   auto sval = operandType.dyn_cast<ShapedType>();
1168   if (!sval) {
1169     return emitOptionalError(location, "only shaped type operands allowed");
1170   }
1171   int64_t dim =
1172       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
1173   auto type = IntegerType::get(context, 17);
1174   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
1175   return success();
1176 }
1177 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)1178 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
1179     OpBuilder &builder, ValueRange operands,
1180     llvm::SmallVectorImpl<Value> &shapes) {
1181   shapes = SmallVector<Value, 1>{
1182       builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1183   return success();
1184 }
1185 
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)1186 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
1187     OpBuilder &builder, ValueRange operands,
1188     llvm::SmallVectorImpl<Value> &shapes) {
1189   Location loc = getLoc();
1190   shapes.reserve(operands.size());
1191   for (Value operand : llvm::reverse(operands)) {
1192     auto rank = operand.getType().cast<RankedTensorType>().getRank();
1193     auto currShape = llvm::to_vector<4>(
1194         llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1195           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1196         }));
1197     shapes.push_back(builder.create<tensor::FromElementsOp>(
1198         getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1199         currShape));
1200   }
1201   return success();
1202 }
1203 
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & shapes)1204 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1205     OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1206   Location loc = getLoc();
1207   shapes.reserve(getNumOperands());
1208   for (Value operand : llvm::reverse(getOperands())) {
1209     auto currShape = llvm::to_vector<4>(llvm::map_range(
1210         llvm::seq<int64_t>(
1211             0, operand.getType().cast<RankedTensorType>().getRank()),
1212         [&](int64_t dim) -> Value {
1213           return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1214         }));
1215     shapes.emplace_back(std::move(currShape));
1216   }
1217   return success();
1218 }
1219 
1220 //===----------------------------------------------------------------------===//
1221 // Test SideEffect interfaces
1222 //===----------------------------------------------------------------------===//
1223 
namespace {
/// A test resource for side effects.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)

  /// Returns the name under which this resource is reported.
  StringRef getName() final { return "<Test>"; }
};
} // namespace
1232 
testSideEffectOpGetEffect(Operation * op,SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> & effects)1233 static void testSideEffectOpGetEffect(
1234     Operation *op,
1235     SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1236         &effects) {
1237   auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1238   if (!effectsAttr)
1239     return;
1240 
1241   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1242 }
1243 
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)1244 void SideEffectOp::getEffects(
1245     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1246   // Check for an effects attribute on the op instance.
1247   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1248   if (!effectsAttr)
1249     return;
1250 
1251   // If there is one, it is an array of dictionary attributes that hold
1252   // information on the effects of this operation.
1253   for (Attribute element : effectsAttr) {
1254     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1255 
1256     // Get the specific memory effect.
1257     MemoryEffects::Effect *effect =
1258         StringSwitch<MemoryEffects::Effect *>(
1259             effectElement.get("effect").cast<StringAttr>().getValue())
1260             .Case("allocate", MemoryEffects::Allocate::get())
1261             .Case("free", MemoryEffects::Free::get())
1262             .Case("read", MemoryEffects::Read::get())
1263             .Case("write", MemoryEffects::Write::get());
1264 
1265     // Check for a non-default resource to use.
1266     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1267     if (effectElement.get("test_resource"))
1268       resource = TestResource::get();
1269 
1270     // Check for a result to affect.
1271     if (effectElement.get("on_result"))
1272       effects.emplace_back(effect, getResult(), resource);
1273     else if (Attribute ref = effectElement.get("on_reference"))
1274       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1275     else
1276       effects.emplace_back(effect, resource);
1277   }
1278 }
1279 
/// Populates the test effects of this op by delegating to
/// `testSideEffectOpGetEffect`.
void SideEffectOp::getEffects(
    SmallVectorImpl<TestEffects::EffectInstance> &effects) {
  testSideEffectOpGetEffect(getOperation(), effects);
}
1284 
1285 //===----------------------------------------------------------------------===//
1286 // StringAttrPrettyNameOp
1287 //===----------------------------------------------------------------------===//
1288 
1289 // This op has fancy handling of its SSA result name.
parse(OpAsmParser & parser,OperationState & result)1290 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1291                                           OperationState &result) {
1292   // Add the result types.
1293   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1294     result.addTypes(parser.getBuilder().getIntegerType(32));
1295 
1296   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1297     return failure();
1298 
1299   // If the attribute dictionary contains no 'names' attribute, infer it from
1300   // the SSA name (if specified).
1301   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1302     return attr.getName() == "names";
1303   });
1304 
1305   // If there was no name specified, check to see if there was a useful name
1306   // specified in the asm file.
1307   if (hadNames || parser.getNumResults() == 0)
1308     return success();
1309 
1310   SmallVector<StringRef, 4> names;
1311   auto *context = result.getContext();
1312 
1313   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1314     auto resultName = parser.getResultName(i);
1315     StringRef nameStr;
1316     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1317       nameStr = resultName.first;
1318 
1319     names.push_back(nameStr);
1320   }
1321 
1322   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1323   result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1324   return success();
1325 }
1326 
print(OpAsmPrinter & p)1327 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1328   // Note that we only need to print the "name" attribute if the asmprinter
1329   // result name disagrees with it.  This can happen in strange cases, e.g.
1330   // when there are conflicts.
1331   bool namesDisagree = getNames().size() != getNumResults();
1332 
1333   SmallString<32> resultNameStr;
1334   for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1335     resultNameStr.clear();
1336     llvm::raw_svector_ostream tmpStream(resultNameStr);
1337     p.printOperand(getResult(i), tmpStream);
1338 
1339     auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1340     if (!expectedName ||
1341         tmpStream.str().drop_front() != expectedName.getValue()) {
1342       namesDisagree = true;
1343     }
1344   }
1345 
1346   if (namesDisagree)
1347     p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1348   else
1349     p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1350 }
1351 
1352 // We set the SSA name in the asm syntax to the contents of the name
1353 // attribute.
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1354 void StringAttrPrettyNameOp::getAsmResultNames(
1355     function_ref<void(Value, StringRef)> setNameFn) {
1356 
1357   auto value = getNames();
1358   for (size_t i = 0, e = value.size(); i != e; ++i)
1359     if (auto str = value[i].dyn_cast<StringAttr>())
1360       if (!str.getValue().empty())
1361         setNameFn(getResult(i), str.getValue());
1362 }
1363 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1364 void CustomResultsNameOp::getAsmResultNames(
1365     function_ref<void(Value, StringRef)> setNameFn) {
1366   ArrayAttr value = getNames();
1367   for (size_t i = 0, e = value.size(); i != e; ++i)
1368     if (auto str = value[i].dyn_cast<StringAttr>())
1369       if (!str.getValue().empty())
1370         setNameFn(getResult(i), str.getValue());
1371 }
1372 
1373 //===----------------------------------------------------------------------===//
1374 // ResultTypeWithTraitOp
1375 //===----------------------------------------------------------------------===//
1376 
verify()1377 LogicalResult ResultTypeWithTraitOp::verify() {
1378   if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1379     return success();
1380   return emitError("result type should have trait 'TestTypeTrait'");
1381 }
1382 
1383 //===----------------------------------------------------------------------===//
1384 // AttrWithTraitOp
1385 //===----------------------------------------------------------------------===//
1386 
verify()1387 LogicalResult AttrWithTraitOp::verify() {
1388   if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1389     return success();
1390   return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1391 }
1392 
1393 //===----------------------------------------------------------------------===//
1394 // RegionIfOp
1395 //===----------------------------------------------------------------------===//
1396 
print(OpAsmPrinter & p)1397 void RegionIfOp::print(OpAsmPrinter &p) {
1398   p << " ";
1399   p.printOperands(getOperands());
1400   p << ": " << getOperandTypes();
1401   p.printArrowTypeList(getResultTypes());
1402   p << " then ";
1403   p.printRegion(getThenRegion(),
1404                 /*printEntryBlockArgs=*/true,
1405                 /*printBlockTerminators=*/true);
1406   p << " else ";
1407   p.printRegion(getElseRegion(),
1408                 /*printEntryBlockArgs=*/true,
1409                 /*printBlockTerminators=*/true);
1410   p << " join ";
1411   p.printRegion(getJoinRegion(),
1412                 /*printEntryBlockArgs=*/true,
1413                 /*printBlockTerminators=*/true);
1414 }
1415 
parse(OpAsmParser & parser,OperationState & result)1416 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1417   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1418   SmallVector<Type, 2> operandTypes;
1419 
1420   result.regions.reserve(3);
1421   Region *thenRegion = result.addRegion();
1422   Region *elseRegion = result.addRegion();
1423   Region *joinRegion = result.addRegion();
1424 
1425   // Parse operand, type and arrow type lists.
1426   if (parser.parseOperandList(operandInfos) ||
1427       parser.parseColonTypeList(operandTypes) ||
1428       parser.parseArrowTypeList(result.types))
1429     return failure();
1430 
1431   // Parse all attached regions.
1432   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1433       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1434       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1435     return failure();
1436 
1437   return parser.resolveOperands(operandInfos, operandTypes,
1438                                 parser.getCurrentLocation(), result.operands);
1439 }
1440 
getSuccessorEntryOperands(Optional<unsigned> index)1441 OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) {
1442   assert(index && *index < 2 && "invalid region index");
1443   return getOperands();
1444 }
1445 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1446 void RegionIfOp::getSuccessorRegions(
1447     Optional<unsigned> index, ArrayRef<Attribute> operands,
1448     SmallVectorImpl<RegionSuccessor> &regions) {
1449   // We always branch to the join region.
1450   if (index.has_value()) {
1451     if (index.value() < 2)
1452       regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1453     else
1454       regions.push_back(RegionSuccessor(getResults()));
1455     return;
1456   }
1457 
1458   // The then and else regions are the entry regions of this op.
1459   regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1460   regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1461 }
1462 
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & invocationBounds)1463 void RegionIfOp::getRegionInvocationBounds(
1464     ArrayRef<Attribute> operands,
1465     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1466   // Each region is invoked at most once.
1467   invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1468 }
1469 
1470 //===----------------------------------------------------------------------===//
1471 // AnyCondOp
1472 //===----------------------------------------------------------------------===//
1473 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1474 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1475                                     ArrayRef<Attribute> operands,
1476                                     SmallVectorImpl<RegionSuccessor> &regions) {
1477   // The parent op branches into the only region, and the region branches back
1478   // to the parent op.
1479   if (!index)
1480     regions.emplace_back(&getRegion());
1481   else
1482     regions.emplace_back(getResults());
1483 }
1484 
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & invocationBounds)1485 void AnyCondOp::getRegionInvocationBounds(
1486     ArrayRef<Attribute> operands,
1487     SmallVectorImpl<InvocationBounds> &invocationBounds) {
1488   invocationBounds.emplace_back(1, 1);
1489 }
1490 
1491 //===----------------------------------------------------------------------===//
1492 // SingleNoTerminatorCustomAsmOp
1493 //===----------------------------------------------------------------------===//
1494 
parse(OpAsmParser & parser,OperationState & state)1495 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1496                                                  OperationState &state) {
1497   Region *body = state.addRegion();
1498   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1499     return failure();
1500   return success();
1501 }
1502 
print(OpAsmPrinter & printer)1503 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1504   printer.printRegion(
1505       getRegion(), /*printEntryBlockArgs=*/false,
1506       // This op has a single block without terminators. But explicitly mark
1507       // as not printing block terminators for testing.
1508       /*printBlockTerminators=*/false);
1509 }
1510 
1511 //===----------------------------------------------------------------------===//
1512 // TestVerifiersOp
1513 //===----------------------------------------------------------------------===//
1514 
verify()1515 LogicalResult TestVerifiersOp::verify() {
1516   if (!getRegion().hasOneBlock())
1517     return emitOpError("`hasOneBlock` trait hasn't been verified");
1518 
1519   Operation *definingOp = getInput().getDefiningOp();
1520   if (definingOp && failed(mlir::verify(definingOp)))
1521     return emitOpError("operand hasn't been verified");
1522 
1523   emitRemark("success run of verifier");
1524 
1525   return success();
1526 }
1527 
verifyRegions()1528 LogicalResult TestVerifiersOp::verifyRegions() {
1529   if (!getRegion().hasOneBlock())
1530     return emitOpError("`hasOneBlock` trait hasn't been verified");
1531 
1532   for (Block &block : getRegion())
1533     for (Operation &op : block)
1534       if (failed(mlir::verify(&op)))
1535         return emitOpError("nested op hasn't been verified");
1536 
1537   emitRemark("success run of region verifier");
1538 
1539   return success();
1540 }
1541 
1542 //===----------------------------------------------------------------------===//
1543 // Test InferIntRangeInterface
1544 //===----------------------------------------------------------------------===//
1545 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRanges)1546 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1547                                          SetIntRangeFn setResultRanges) {
1548   setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
1549 }
1550 
parse(OpAsmParser & parser,OperationState & result)1551 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
1552                                           OperationState &result) {
1553   if (parser.parseOptionalAttrDict(result.attributes))
1554     return failure();
1555 
1556   // Parse the input argument
1557   OpAsmParser::Argument argInfo;
1558   argInfo.type = parser.getBuilder().getIndexType();
1559   if (failed(parser.parseArgument(argInfo)))
1560     return failure();
1561 
1562   // Parse the body region, and reuse the operand info as the argument info.
1563   Region *body = result.addRegion();
1564   return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
1565 }
1566 
print(OpAsmPrinter & p)1567 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
1568   p.printOptionalAttrDict((*this)->getAttrs());
1569   p << ' ';
1570   p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
1571                         /*omitType=*/true);
1572   p << ' ';
1573   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
1574 }
1575 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRanges)1576 void TestWithBoundsRegionOp::inferResultRanges(
1577     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1578   Value arg = getRegion().getArgument(0);
1579   setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
1580 }
1581 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRanges)1582 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1583                                         SetIntRangeFn setResultRanges) {
1584   const ConstantIntRanges &range = argRanges[0];
1585   APInt one(range.umin().getBitWidth(), 1);
1586   setResultRanges(getResult(),
1587                   {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
1588                    range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
1589 }
1590 
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRanges)1591 void TestReflectBoundsOp::inferResultRanges(
1592     ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1593   const ConstantIntRanges &range = argRanges[0];
1594   MLIRContext *ctx = getContext();
1595   Builder b(ctx);
1596   setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
1597   setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
1598   setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
1599   setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
1600   setResultRanges(getResult(), range);
1601 }
1602 
1603 #include "TestOpEnums.cpp.inc"
1604 #include "TestOpInterfaces.cpp.inc"
1605 #include "TestTypeInterfaces.cpp.inc"
1606 
1607 #define GET_OP_CLASSES
1608 #include "TestOps.cpp.inc"
1609