1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestInterfaces.h"
12 #include "TestTypes.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/DLTI/DLTI.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/Dialect/Tensor/IR/Tensor.h"
17 #include "mlir/IR/AsmState.h"
18 #include "mlir/IR/BuiltinAttributes.h"
19 #include "mlir/IR/BuiltinOps.h"
20 #include "mlir/IR/Diagnostics.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/ExtensibleDialect.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/IR/TypeUtilities.h"
27 #include "mlir/IR/Verifier.h"
28 #include "mlir/Interfaces/InferIntRangeInterface.h"
29 #include "mlir/Reducer/ReductionPatternInterface.h"
30 #include "mlir/Transforms/FoldUtils.h"
31 #include "mlir/Transforms/InliningUtils.h"
32 #include "llvm/ADT/SmallString.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/ADT/StringSwitch.h"
35
36 // Include this before the using namespace lines below to
37 // test that we don't have namespace dependencies.
38 #include "TestOpsDialect.cpp.inc"
39
40 using namespace mlir;
41 using namespace test;
42
registerTestDialect(DialectRegistry & registry)43 void test::registerTestDialect(DialectRegistry ®istry) {
44 registry.insert<TestDialect>();
45 }
46
47 //===----------------------------------------------------------------------===//
48 // External Elements Data
49 //===----------------------------------------------------------------------===//
50
getData() const51 ArrayRef<uint64_t> TestExternalElementsData::getData() const {
52 ArrayRef<char> data = AsmResourceBlob::getData();
53 return ArrayRef<uint64_t>((const uint64_t *)data.data(),
54 data.size() / sizeof(uint64_t));
55 }
56
57 TestExternalElementsData
allocate(size_t numElements)58 TestExternalElementsData::allocate(size_t numElements) {
59 return TestExternalElementsData(
60 llvm::ArrayRef<uint64_t>(new uint64_t[numElements], numElements),
61 [](const uint64_t *data, size_t) { delete[] data; },
62 /*dataIsMutable=*/true);
63 }
64
65 const TestExternalElementsData *
getData(StringRef name) const66 TestExternalElementsDataManager::getData(StringRef name) const {
67 auto it = dataMap.find(name);
68 return it != dataMap.end() ? &*it->second : nullptr;
69 }
70
71 std::pair<TestExternalElementsDataManager::DataMap::iterator, bool>
insert(StringRef name)72 TestExternalElementsDataManager::insert(StringRef name) {
73 auto it = dataMap.try_emplace(name, nullptr);
74 if (it.second)
75 return it;
76
77 llvm::SmallString<32> nameStorage(name);
78 nameStorage.push_back('_');
79 size_t nameCounter = 1;
80 do {
81 nameStorage += std::to_string(nameCounter++);
82 auto it = dataMap.try_emplace(nameStorage, nullptr);
83 if (it.second)
84 return it;
85 nameStorage.resize(name.size() + 1);
86 } while (true);
87 }
88
setData(StringRef name,TestExternalElementsData && data)89 void TestExternalElementsDataManager::setData(StringRef name,
90 TestExternalElementsData &&data) {
91 auto it = dataMap.find(name);
92 assert(it != dataMap.end() && "data not registered");
93 it->second = std::make_unique<TestExternalElementsData>(std::move(data));
94 }
95
96 //===----------------------------------------------------------------------===//
97 // TestDialect Interfaces
98 //===----------------------------------------------------------------------===//
99
100 namespace {
101
102 /// Testing the correctness of some traits.
103 static_assert(
104 llvm::is_detected<OpTrait::has_implicit_terminator_t,
105 SingleBlockImplicitTerminatorOp>::value,
106 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
107 static_assert(OpTrait::hasSingleBlockImplicitTerminator<
108 SingleBlockImplicitTerminatorOp>::value,
109 "hasSingleBlockImplicitTerminator does not match "
110 "SingleBlockImplicitTerminatorOp");
111
112 // Test support for interacting with the AsmPrinter.
113 struct TestOpAsmInterface : public OpAsmDialectInterface {
114 using OpAsmDialectInterface::OpAsmDialectInterface;
115
116 //===------------------------------------------------------------------===//
117 // Aliases
118 //===------------------------------------------------------------------===//
119
getAlias__anonf77d94720211::TestOpAsmInterface120 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
121 StringAttr strAttr = attr.dyn_cast<StringAttr>();
122 if (!strAttr)
123 return AliasResult::NoAlias;
124
125 // Check the contents of the string attribute to see what the test alias
126 // should be named.
127 Optional<StringRef> aliasName =
128 StringSwitch<Optional<StringRef>>(strAttr.getValue())
129 .Case("alias_test:dot_in_name", StringRef("test.alias"))
130 .Case("alias_test:trailing_digit", StringRef("test_alias0"))
131 .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
132 .Case("alias_test:sanitize_conflict_a",
133 StringRef("test_alias_conflict0"))
134 .Case("alias_test:sanitize_conflict_b",
135 StringRef("test_alias_conflict0_"))
136 .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
137 .Default(llvm::None);
138 if (!aliasName)
139 return AliasResult::NoAlias;
140
141 os << *aliasName;
142 return AliasResult::FinalAlias;
143 }
144
getAlias__anonf77d94720211::TestOpAsmInterface145 AliasResult getAlias(Type type, raw_ostream &os) const final {
146 if (auto tupleType = type.dyn_cast<TupleType>()) {
147 if (tupleType.size() > 0 &&
148 llvm::all_of(tupleType.getTypes(), [](Type elemType) {
149 return elemType.isa<SimpleAType>();
150 })) {
151 os << "test_tuple";
152 return AliasResult::FinalAlias;
153 }
154 }
155 if (auto intType = type.dyn_cast<TestIntegerType>()) {
156 if (intType.getSignedness() ==
157 TestIntegerType::SignednessSemantics::Unsigned &&
158 intType.getWidth() == 8) {
159 os << "test_ui8";
160 return AliasResult::FinalAlias;
161 }
162 }
163 if (auto recType = type.dyn_cast<TestRecursiveType>()) {
164 if (recType.getName() == "type_to_alias") {
165 // We only make alias for a specific recursive type.
166 os << "testrec";
167 return AliasResult::FinalAlias;
168 }
169 }
170 return AliasResult::NoAlias;
171 }
172
173 //===------------------------------------------------------------------===//
174 // Resources
175 //===------------------------------------------------------------------===//
176
177 std::string
getResourceKey__anonf77d94720211::TestOpAsmInterface178 getResourceKey(const AsmDialectResourceHandle &handle) const override {
179 return cast<TestExternalElementsDataHandle>(handle).getKey().str();
180 }
181
182 FailureOr<AsmDialectResourceHandle>
declareResource__anonf77d94720211::TestOpAsmInterface183 declareResource(StringRef key) const final {
184 TestDialect *dialect = cast<TestDialect>(getDialect());
185 TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
186
187 // Resolve the reference by inserting a new entry into the manager.
188 auto it = mgr.insert(key).first;
189 return TestExternalElementsDataHandle(&*it, dialect);
190 }
191
parseResource__anonf77d94720211::TestOpAsmInterface192 LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
193 TestDialect *dialect = cast<TestDialect>(getDialect());
194 TestExternalElementsDataManager &mgr = dialect->getExternalDataManager();
195
196 // The resource entries are external constant data.
197 auto blobAllocFn = [](unsigned size, unsigned align) {
198 assert(align == alignof(uint64_t) && "unexpected data alignment");
199 return TestExternalElementsData::allocate(size / sizeof(uint64_t));
200 };
201 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob(blobAllocFn);
202 if (failed(blob))
203 return failure();
204
205 mgr.setData(entry.getKey(), std::move(*blob));
206 return success();
207 }
208
209 void
buildResources__anonf77d94720211::TestOpAsmInterface210 buildResources(Operation *op,
211 const SetVector<AsmDialectResourceHandle> &referencedResources,
212 AsmResourceBuilder &provider) const final {
213 for (const AsmDialectResourceHandle &handle : referencedResources) {
214 const auto &testHandle = cast<TestExternalElementsDataHandle>(handle);
215 provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData());
216 }
217 }
218 };
219
220 struct TestDialectFoldInterface : public DialectFoldInterface {
221 using DialectFoldInterface::DialectFoldInterface;
222
223 /// Registered hook to check if the given region, which is attached to an
224 /// operation that is *not* isolated from above, should be used when
225 /// materializing constants.
shouldMaterializeInto__anonf77d94720211::TestDialectFoldInterface226 bool shouldMaterializeInto(Region *region) const final {
227 // If this is a one region operation, then insert into it.
228 return isa<OneRegionOp>(region->getParentOp());
229 }
230 };
231
232 /// This class defines the interface for handling inlining with standard
233 /// operations.
234 struct TestInlinerInterface : public DialectInlinerInterface {
235 using DialectInlinerInterface::DialectInlinerInterface;
236
237 //===--------------------------------------------------------------------===//
238 // Analysis Hooks
239 //===--------------------------------------------------------------------===//
240
isLegalToInline__anonf77d94720211::TestInlinerInterface241 bool isLegalToInline(Operation *call, Operation *callable,
242 bool wouldBeCloned) const final {
243 // Don't allow inlining calls that are marked `noinline`.
244 return !call->hasAttr("noinline");
245 }
isLegalToInline__anonf77d94720211::TestInlinerInterface246 bool isLegalToInline(Region *, Region *, bool,
247 BlockAndValueMapping &) const final {
248 // Inlining into test dialect regions is legal.
249 return true;
250 }
isLegalToInline__anonf77d94720211::TestInlinerInterface251 bool isLegalToInline(Operation *, Region *, bool,
252 BlockAndValueMapping &) const final {
253 return true;
254 }
255
shouldAnalyzeRecursively__anonf77d94720211::TestInlinerInterface256 bool shouldAnalyzeRecursively(Operation *op) const final {
257 // Analyze recursively if this is not a functional region operation, it
258 // froms a separate functional scope.
259 return !isa<FunctionalRegionOp>(op);
260 }
261
262 //===--------------------------------------------------------------------===//
263 // Transformation Hooks
264 //===--------------------------------------------------------------------===//
265
266 /// Handle the given inlined terminator by replacing it with a new operation
267 /// as necessary.
handleTerminator__anonf77d94720211::TestInlinerInterface268 void handleTerminator(Operation *op,
269 ArrayRef<Value> valuesToRepl) const final {
270 // Only handle "test.return" here.
271 auto returnOp = dyn_cast<TestReturnOp>(op);
272 if (!returnOp)
273 return;
274
275 // Replace the values directly with the return operands.
276 assert(returnOp.getNumOperands() == valuesToRepl.size());
277 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
278 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
279 }
280
281 /// Attempt to materialize a conversion for a type mismatch between a call
282 /// from this dialect, and a callable region. This method should generate an
283 /// operation that takes 'input' as the only operand, and produces a single
284 /// result of 'resultType'. If a conversion can not be generated, nullptr
285 /// should be returned.
materializeCallConversion__anonf77d94720211::TestInlinerInterface286 Operation *materializeCallConversion(OpBuilder &builder, Value input,
287 Type resultType,
288 Location conversionLoc) const final {
289 // Only allow conversion for i16/i32 types.
290 if (!(resultType.isSignlessInteger(16) ||
291 resultType.isSignlessInteger(32)) ||
292 !(input.getType().isSignlessInteger(16) ||
293 input.getType().isSignlessInteger(32)))
294 return nullptr;
295 return builder.create<TestCastOp>(conversionLoc, resultType, input);
296 }
297
processInlinedCallBlocks__anonf77d94720211::TestInlinerInterface298 void processInlinedCallBlocks(
299 Operation *call,
300 iterator_range<Region::iterator> inlinedBlocks) const final {
301 if (!isa<ConversionCallOp>(call))
302 return;
303
304 // Set attributed on all ops in the inlined blocks.
305 for (Block &block : inlinedBlocks) {
306 block.walk([&](Operation *op) {
307 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
308 });
309 }
310 }
311 };
312
313 struct TestReductionPatternInterface : public DialectReductionPatternInterface {
314 public:
TestReductionPatternInterface__anonf77d94720211::TestReductionPatternInterface315 TestReductionPatternInterface(Dialect *dialect)
316 : DialectReductionPatternInterface(dialect) {}
317
populateReductionPatterns__anonf77d94720211::TestReductionPatternInterface318 void populateReductionPatterns(RewritePatternSet &patterns) const final {
319 populateTestReductionPatterns(patterns);
320 }
321 };
322
323 } // namespace
324
325 //===----------------------------------------------------------------------===//
326 // Dynamic operations
327 //===----------------------------------------------------------------------===//
328
getDynamicGenericOp(TestDialect * dialect)329 std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
330 return DynamicOpDefinition::get(
331 "dynamic_generic", dialect, [](Operation *op) { return success(); },
332 [](Operation *op) { return success(); });
333 }
334
335 std::unique_ptr<DynamicOpDefinition>
getDynamicOneOperandTwoResultsOp(TestDialect * dialect)336 getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
337 return DynamicOpDefinition::get(
338 "dynamic_one_operand_two_results", dialect,
339 [](Operation *op) {
340 if (op->getNumOperands() != 1) {
341 op->emitOpError()
342 << "expected 1 operand, but had " << op->getNumOperands();
343 return failure();
344 }
345 if (op->getNumResults() != 2) {
346 op->emitOpError()
347 << "expected 2 results, but had " << op->getNumResults();
348 return failure();
349 }
350 return success();
351 },
352 [](Operation *op) { return success(); });
353 }
354
355 std::unique_ptr<DynamicOpDefinition>
getDynamicCustomParserPrinterOp(TestDialect * dialect)356 getDynamicCustomParserPrinterOp(TestDialect *dialect) {
357 auto verifier = [](Operation *op) {
358 if (op->getNumOperands() == 0 && op->getNumResults() == 0)
359 return success();
360 op->emitError() << "operation should have no operands and no results";
361 return failure();
362 };
363 auto regionVerifier = [](Operation *op) { return success(); };
364
365 auto parser = [](OpAsmParser &parser, OperationState &state) {
366 return parser.parseKeyword("custom_keyword");
367 };
368
369 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
370 printer << op->getName() << " custom_keyword";
371 };
372
373 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
374 verifier, regionVerifier, parser, printer);
375 }
376
377 //===----------------------------------------------------------------------===//
378 // TestDialect
379 //===----------------------------------------------------------------------===//
380
381 static void testSideEffectOpGetEffect(
382 Operation *op,
383 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
384
385 // This is the implementation of a dialect fallback for `TestEffectOpInterface`.
386 struct TestOpEffectInterfaceFallback
387 : public TestEffectOpInterface::FallbackModel<
388 TestOpEffectInterfaceFallback> {
classofTestOpEffectInterfaceFallback389 static bool classof(Operation *op) {
390 bool isSupportedOp =
391 op->getName().getStringRef() == "test.unregistered_side_effect_op";
392 assert(isSupportedOp && "Unexpected dispatch");
393 return isSupportedOp;
394 }
395
396 void
getEffectsTestOpEffectInterfaceFallback397 getEffects(Operation *op,
398 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
399 &effects) const {
400 testSideEffectOpGetEffect(op, effects);
401 }
402 };
403
initialize()404 void TestDialect::initialize() {
405 registerAttributes();
406 registerTypes();
407 addOperations<
408 #define GET_OP_LIST
409 #include "TestOps.cpp.inc"
410 >();
411 registerDynamicOp(getDynamicGenericOp(this));
412 registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
413 registerDynamicOp(getDynamicCustomParserPrinterOp(this));
414
415 addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
416 TestInlinerInterface, TestReductionPatternInterface>();
417 allowUnknownOperations();
418
419 // Instantiate our fallback op interface that we'll use on specific
420 // unregistered op.
421 fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
422 }
~TestDialect()423 TestDialect::~TestDialect() {
424 delete static_cast<TestOpEffectInterfaceFallback *>(
425 fallbackEffectOpInterfaces);
426 }
427
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)428 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
429 Type type, Location loc) {
430 return builder.create<TestOpConstant>(loc, type, value);
431 }
432
inferReturnTypes(::mlir::MLIRContext * context,::llvm::Optional<::mlir::Location> location,::mlir::ValueRange operands,::mlir::DictionaryAttr attributes,::mlir::RegionRange regions,::llvm::SmallVectorImpl<::mlir::Type> & inferredReturnTypes)433 ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
434 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
435 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
436 ::mlir::RegionRange regions,
437 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
438 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
439 return ::mlir::success();
440 }
441
getRegisteredInterfaceForOp(TypeID typeID,OperationName opName)442 void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
443 OperationName opName) {
444 if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
445 typeID == TypeID::get<TestEffectOpInterface>())
446 return fallbackEffectOpInterfaces;
447 return nullptr;
448 }
449
verifyOperationAttribute(Operation * op,NamedAttribute namedAttr)450 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
451 NamedAttribute namedAttr) {
452 if (namedAttr.getName() == "test.invalid_attr")
453 return op->emitError() << "invalid to use 'test.invalid_attr'";
454 return success();
455 }
456
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute namedAttr)457 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
458 unsigned regionIndex,
459 unsigned argIndex,
460 NamedAttribute namedAttr) {
461 if (namedAttr.getName() == "test.invalid_attr")
462 return op->emitError() << "invalid to use 'test.invalid_attr'";
463 return success();
464 }
465
466 LogicalResult
verifyRegionResultAttribute(Operation * op,unsigned regionIndex,unsigned resultIndex,NamedAttribute namedAttr)467 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
468 unsigned resultIndex,
469 NamedAttribute namedAttr) {
470 if (namedAttr.getName() == "test.invalid_attr")
471 return op->emitError() << "invalid to use 'test.invalid_attr'";
472 return success();
473 }
474
475 Optional<Dialect::ParseOpHook>
getParseOperationHook(StringRef opName) const476 TestDialect::getParseOperationHook(StringRef opName) const {
477 if (opName == "test.dialect_custom_printer") {
478 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
479 return parser.parseKeyword("custom_format");
480 }};
481 }
482 if (opName == "test.dialect_custom_format_fallback") {
483 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
484 return parser.parseKeyword("custom_format_fallback");
485 }};
486 }
487 if (opName == "test.dialect_custom_printer.with.dot") {
488 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
489 return ParseResult::success();
490 }};
491 }
492 return None;
493 }
494
495 llvm::unique_function<void(Operation *, OpAsmPrinter &)>
getOperationPrinter(Operation * op) const496 TestDialect::getOperationPrinter(Operation *op) const {
497 StringRef opName = op->getName().getStringRef();
498 if (opName == "test.dialect_custom_printer") {
499 return [](Operation *op, OpAsmPrinter &printer) {
500 printer.getStream() << " custom_format";
501 };
502 }
503 if (opName == "test.dialect_custom_format_fallback") {
504 return [](Operation *op, OpAsmPrinter &printer) {
505 printer.getStream() << " custom_format_fallback";
506 };
507 }
508 return {};
509 }
510
511 //===----------------------------------------------------------------------===//
512 // TestBranchOp
513 //===----------------------------------------------------------------------===//
514
getSuccessorOperands(unsigned index)515 SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
516 assert(index == 0 && "invalid successor index");
517 return SuccessorOperands(getTargetOperandsMutable());
518 }
519
520 //===----------------------------------------------------------------------===//
521 // TestProducingBranchOp
522 //===----------------------------------------------------------------------===//
523
getSuccessorOperands(unsigned index)524 SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
525 assert(index <= 1 && "invalid successor index");
526 if (index == 1)
527 return SuccessorOperands(getFirstOperandsMutable());
528 return SuccessorOperands(getSecondOperandsMutable());
529 }
530
531 //===----------------------------------------------------------------------===//
// TestInternalBranchOp
533 //===----------------------------------------------------------------------===//
534
getSuccessorOperands(unsigned index)535 SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
536 assert(index <= 1 && "invalid successor index");
537 if (index == 0)
538 return SuccessorOperands(0, getSuccessOperandsMutable());
539 return SuccessorOperands(1, getErrorOperandsMutable());
540 }
541
542 //===----------------------------------------------------------------------===//
543 // TestDialectCanonicalizerOp
544 //===----------------------------------------------------------------------===//
545
546 static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,PatternRewriter & rewriter)547 dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
548 PatternRewriter &rewriter) {
549 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
550 op, rewriter.getI32IntegerAttr(42));
551 return success();
552 }
553
getCanonicalizationPatterns(RewritePatternSet & results) const554 void TestDialect::getCanonicalizationPatterns(
555 RewritePatternSet &results) const {
556 results.add(&dialectCanonicalizationPattern);
557 }
558
559 //===----------------------------------------------------------------------===//
560 // TestCallOp
561 //===----------------------------------------------------------------------===//
562
verifySymbolUses(SymbolTableCollection & symbolTable)563 LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
564 // Check that the callee attribute was specified.
565 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
566 if (!fnAttr)
567 return emitOpError("requires a 'callee' symbol reference attribute");
568 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
569 return emitOpError() << "'" << fnAttr.getValue()
570 << "' does not reference a valid function";
571 return success();
572 }
573
574 //===----------------------------------------------------------------------===//
575 // TestFoldToCallOp
576 //===----------------------------------------------------------------------===//
577
578 namespace {
579 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
580 using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
581
matchAndRewrite__anonf77d94721311::FoldToCallOpPattern582 LogicalResult matchAndRewrite(FoldToCallOp op,
583 PatternRewriter &rewriter) const override {
584 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
585 op.getCalleeAttr(), ValueRange());
586 return success();
587 }
588 };
589 } // namespace
590
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)591 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
592 MLIRContext *context) {
593 results.add<FoldToCallOpPattern>(context);
594 }
595
596 //===----------------------------------------------------------------------===//
597 // Test Format* operations
598 //===----------------------------------------------------------------------===//
599
600 //===----------------------------------------------------------------------===//
601 // Parsing
602
parseCustomOptionalOperand(OpAsmParser & parser,Optional<OpAsmParser::UnresolvedOperand> & optOperand)603 static ParseResult parseCustomOptionalOperand(
604 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
605 if (succeeded(parser.parseOptionalLParen())) {
606 optOperand.emplace();
607 if (parser.parseOperand(*optOperand) || parser.parseRParen())
608 return failure();
609 }
610 return success();
611 }
612
parseCustomDirectiveOperands(OpAsmParser & parser,OpAsmParser::UnresolvedOperand & operand,Optional<OpAsmParser::UnresolvedOperand> & optOperand,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & varOperands)613 static ParseResult parseCustomDirectiveOperands(
614 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
615 Optional<OpAsmParser::UnresolvedOperand> &optOperand,
616 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
617 if (parser.parseOperand(operand))
618 return failure();
619 if (succeeded(parser.parseOptionalComma())) {
620 optOperand.emplace();
621 if (parser.parseOperand(*optOperand))
622 return failure();
623 }
624 if (parser.parseArrow() || parser.parseLParen() ||
625 parser.parseOperandList(varOperands) || parser.parseRParen())
626 return failure();
627 return success();
628 }
629 static ParseResult
parseCustomDirectiveResults(OpAsmParser & parser,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)630 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
631 Type &optOperandType,
632 SmallVectorImpl<Type> &varOperandTypes) {
633 if (parser.parseColon())
634 return failure();
635
636 if (parser.parseType(operandType))
637 return failure();
638 if (succeeded(parser.parseOptionalComma())) {
639 if (parser.parseType(optOperandType))
640 return failure();
641 }
642 if (parser.parseArrow() || parser.parseLParen() ||
643 parser.parseTypeList(varOperandTypes) || parser.parseRParen())
644 return failure();
645 return success();
646 }
647 static ParseResult
parseCustomDirectiveWithTypeRefs(OpAsmParser & parser,Type operandType,Type optOperandType,const SmallVectorImpl<Type> & varOperandTypes)648 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
649 Type optOperandType,
650 const SmallVectorImpl<Type> &varOperandTypes) {
651 if (parser.parseKeyword("type_refs_capture"))
652 return failure();
653
654 Type operandType2, optOperandType2;
655 SmallVector<Type, 1> varOperandTypes2;
656 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
657 varOperandTypes2))
658 return failure();
659
660 if (operandType != operandType2 || optOperandType != optOperandType2 ||
661 varOperandTypes != varOperandTypes2)
662 return failure();
663
664 return success();
665 }
parseCustomDirectiveOperandsAndTypes(OpAsmParser & parser,OpAsmParser::UnresolvedOperand & operand,Optional<OpAsmParser::UnresolvedOperand> & optOperand,SmallVectorImpl<OpAsmParser::UnresolvedOperand> & varOperands,Type & operandType,Type & optOperandType,SmallVectorImpl<Type> & varOperandTypes)666 static ParseResult parseCustomDirectiveOperandsAndTypes(
667 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
668 Optional<OpAsmParser::UnresolvedOperand> &optOperand,
669 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
670 Type &operandType, Type &optOperandType,
671 SmallVectorImpl<Type> &varOperandTypes) {
672 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
673 parseCustomDirectiveResults(parser, operandType, optOperandType,
674 varOperandTypes))
675 return failure();
676 return success();
677 }
parseCustomDirectiveRegions(OpAsmParser & parser,Region & region,SmallVectorImpl<std::unique_ptr<Region>> & varRegions)678 static ParseResult parseCustomDirectiveRegions(
679 OpAsmParser &parser, Region ®ion,
680 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
681 if (parser.parseRegion(region))
682 return failure();
683 if (failed(parser.parseOptionalComma()))
684 return success();
685 std::unique_ptr<Region> varRegion = std::make_unique<Region>();
686 if (parser.parseRegion(*varRegion))
687 return failure();
688 varRegions.emplace_back(std::move(varRegion));
689 return success();
690 }
691 static ParseResult
parseCustomDirectiveSuccessors(OpAsmParser & parser,Block * & successor,SmallVectorImpl<Block * > & varSuccessors)692 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
693 SmallVectorImpl<Block *> &varSuccessors) {
694 if (parser.parseSuccessor(successor))
695 return failure();
696 if (failed(parser.parseOptionalComma()))
697 return success();
698 Block *varSuccessor;
699 if (parser.parseSuccessor(varSuccessor))
700 return failure();
701 varSuccessors.append(2, varSuccessor);
702 return success();
703 }
parseCustomDirectiveAttributes(OpAsmParser & parser,IntegerAttr & attr,IntegerAttr & optAttr)704 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
705 IntegerAttr &attr,
706 IntegerAttr &optAttr) {
707 if (parser.parseAttribute(attr))
708 return failure();
709 if (succeeded(parser.parseOptionalComma())) {
710 if (parser.parseAttribute(optAttr))
711 return failure();
712 }
713 return success();
714 }
715
parseCustomDirectiveAttrDict(OpAsmParser & parser,NamedAttrList & attrs)716 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
717 NamedAttrList &attrs) {
718 return parser.parseOptionalAttrDict(attrs);
719 }
parseCustomDirectiveOptionalOperandRef(OpAsmParser & parser,Optional<OpAsmParser::UnresolvedOperand> & optOperand)720 static ParseResult parseCustomDirectiveOptionalOperandRef(
721 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
722 int64_t operandCount = 0;
723 if (parser.parseInteger(operandCount))
724 return failure();
725 bool expectedOptionalOperand = operandCount == 0;
726 return success(expectedOptionalOperand != optOperand.has_value());
727 }
728
729 //===----------------------------------------------------------------------===//
730 // Printing
731
printCustomOptionalOperand(OpAsmPrinter & printer,Operation *,Value optOperand)732 static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
733 Value optOperand) {
734 if (optOperand)
735 printer << "(" << optOperand << ") ";
736 }
737
printCustomDirectiveOperands(OpAsmPrinter & printer,Operation *,Value operand,Value optOperand,OperandRange varOperands)738 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
739 Value operand, Value optOperand,
740 OperandRange varOperands) {
741 printer << operand;
742 if (optOperand)
743 printer << ", " << optOperand;
744 printer << " -> (" << varOperands << ")";
745 }
printCustomDirectiveResults(OpAsmPrinter & printer,Operation *,Type operandType,Type optOperandType,TypeRange varOperandTypes)746 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
747 Type operandType, Type optOperandType,
748 TypeRange varOperandTypes) {
749 printer << " : " << operandType;
750 if (optOperandType)
751 printer << ", " << optOperandType;
752 printer << " -> (" << varOperandTypes << ")";
753 }
printCustomDirectiveWithTypeRefs(OpAsmPrinter & printer,Operation * op,Type operandType,Type optOperandType,TypeRange varOperandTypes)754 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
755 Operation *op, Type operandType,
756 Type optOperandType,
757 TypeRange varOperandTypes) {
758 printer << " type_refs_capture ";
759 printCustomDirectiveResults(printer, op, operandType, optOperandType,
760 varOperandTypes);
761 }
printCustomDirectiveOperandsAndTypes(OpAsmPrinter & printer,Operation * op,Value operand,Value optOperand,OperandRange varOperands,Type operandType,Type optOperandType,TypeRange varOperandTypes)762 static void printCustomDirectiveOperandsAndTypes(
763 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
764 OperandRange varOperands, Type operandType, Type optOperandType,
765 TypeRange varOperandTypes) {
766 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
767 printCustomDirectiveResults(printer, op, operandType, optOperandType,
768 varOperandTypes);
769 }
printCustomDirectiveRegions(OpAsmPrinter & printer,Operation *,Region & region,MutableArrayRef<Region> varRegions)770 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
771 Region ®ion,
772 MutableArrayRef<Region> varRegions) {
773 printer.printRegion(region);
774 if (!varRegions.empty()) {
775 printer << ", ";
776 for (Region ®ion : varRegions)
777 printer.printRegion(region);
778 }
779 }
printCustomDirectiveSuccessors(OpAsmPrinter & printer,Operation *,Block * successor,SuccessorRange varSuccessors)780 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
781 Block *successor,
782 SuccessorRange varSuccessors) {
783 printer << successor;
784 if (!varSuccessors.empty())
785 printer << ", " << varSuccessors.front();
786 }
printCustomDirectiveAttributes(OpAsmPrinter & printer,Operation *,Attribute attribute,Attribute optAttribute)787 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
788 Attribute attribute,
789 Attribute optAttribute) {
790 printer << attribute;
791 if (optAttribute)
792 printer << ", " << optAttribute;
793 }
794
printCustomDirectiveAttrDict(OpAsmPrinter & printer,Operation * op,DictionaryAttr attrs)795 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
796 DictionaryAttr attrs) {
797 printer.printOptionalAttrDict(attrs.getValue());
798 }
799
printCustomDirectiveOptionalOperandRef(OpAsmPrinter & printer,Operation * op,Value optOperand)800 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
801 Operation *op,
802 Value optOperand) {
803 printer << (optOperand ? "1" : "0");
804 }
805
806 //===----------------------------------------------------------------------===//
807 // Test IsolatedRegionOp - parse passthrough region arguments.
808 //===----------------------------------------------------------------------===//
809
parse(OpAsmParser & parser,OperationState & result)810 ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
811 OperationState &result) {
812 // Parse the input operand.
813 OpAsmParser::Argument argInfo;
814 argInfo.type = parser.getBuilder().getIndexType();
815 if (parser.parseOperand(argInfo.ssaName) ||
816 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
817 return failure();
818
819 // Parse the body region, and reuse the operand info as the argument info.
820 Region *body = result.addRegion();
821 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
822 }
823
print(OpAsmPrinter & p)824 void IsolatedRegionOp::print(OpAsmPrinter &p) {
825 p << "test.isolated_region ";
826 p.printOperand(getOperand());
827 p.shadowRegionArgs(getRegion(), getOperand());
828 p << ' ';
829 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
830 }
831
832 //===----------------------------------------------------------------------===//
833 // Test SSACFGRegionOp
834 //===----------------------------------------------------------------------===//
835
getRegionKind(unsigned index)836 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
837 return RegionKind::SSACFG;
838 }
839
840 //===----------------------------------------------------------------------===//
841 // Test GraphRegionOp
842 //===----------------------------------------------------------------------===//
843
getRegionKind(unsigned index)844 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
845 return RegionKind::Graph;
846 }
847
848 //===----------------------------------------------------------------------===//
849 // Test AffineScopeOp
850 //===----------------------------------------------------------------------===//
851
parse(OpAsmParser & parser,OperationState & result)852 ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
853 // Parse the body region, and reuse the operand info as the argument info.
854 Region *body = result.addRegion();
855 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
856 }
857
print(OpAsmPrinter & p)858 void AffineScopeOp::print(OpAsmPrinter &p) {
859 p << "test.affine_scope ";
860 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
861 }
862
863 //===----------------------------------------------------------------------===//
864 // Test parser.
865 //===----------------------------------------------------------------------===//
866
parse(OpAsmParser & parser,OperationState & result)867 ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
868 OperationState &result) {
869 if (parser.parseOptionalColon())
870 return success();
871 uint64_t numResults;
872 if (parser.parseInteger(numResults))
873 return failure();
874
875 IndexType type = parser.getBuilder().getIndexType();
876 for (unsigned i = 0; i < numResults; ++i)
877 result.addTypes(type);
878 return success();
879 }
880
print(OpAsmPrinter & p)881 void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
882 if (unsigned numResults = getNumResults())
883 p << " : " << numResults;
884 }
885
parse(OpAsmParser & parser,OperationState & result)886 ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
887 OperationState &result) {
888 StringRef keyword;
889 if (parser.parseKeyword(&keyword))
890 return failure();
891 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
892 return success();
893 }
894
print(OpAsmPrinter & p)895 void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
896
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
899
parse(OpAsmParser & parser,OperationState & result)900 ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
901 OperationState &result) {
902 if (parser.parseKeyword("wraps"))
903 return failure();
904
905 // Parse the wrapped op in a region
906 Region &body = *result.addRegion();
907 body.push_back(new Block);
908 Block &block = body.back();
909 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
910 if (!wrappedOp)
911 return failure();
912
913 // Create a return terminator in the inner region, pass as operand to the
914 // terminator the returned values from the wrapped operation.
915 SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
916 OpBuilder builder(parser.getContext());
917 builder.setInsertionPointToEnd(&block);
918 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
919
920 // Get the results type for the wrapping op from the terminator operands.
921 Operation &returnOp = body.back().back();
922 result.types.append(returnOp.operand_type_begin(),
923 returnOp.operand_type_end());
924
925 // Use the location of the wrapped op for the "test.wrapping_region" op.
926 result.location = wrappedOp->getLoc();
927
928 return success();
929 }
930
print(OpAsmPrinter & p)931 void WrappingRegionOp::print(OpAsmPrinter &p) {
932 p << " wraps ";
933 p.printGenericOp(&getRegion().front().front());
934 }
935
936 //===----------------------------------------------------------------------===//
937 // Test PrettyPrintedRegionOp - exercising the following parser APIs
938 // parseGenericOperationAfterOpName
939 // parseCustomOperationName
940 //===----------------------------------------------------------------------===//
941
parse(OpAsmParser & parser,OperationState & result)942 ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
943 OperationState &result) {
944
945 SMLoc loc = parser.getCurrentLocation();
946 Location currLocation = parser.getEncodedSourceLoc(loc);
947
948 // Parse the operands.
949 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
950 if (parser.parseOperandList(operands))
951 return failure();
952
953 // Check if we are parsing the pretty-printed version
954 // test.pretty_printed_region start <inner-op> end : <functional-type>
955 // Else fallback to parsing the "non pretty-printed" version.
956 if (!succeeded(parser.parseOptionalKeyword("start")))
957 return parser.parseGenericOperationAfterOpName(
958 result, llvm::makeArrayRef(operands));
959
960 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
961 if (failed(parseOpNameInfo))
962 return failure();
963
964 StringAttr innerOpName = parseOpNameInfo->getIdentifier();
965
966 FunctionType opFntype;
967 Optional<Location> explicitLoc;
968 if (parser.parseKeyword("end") || parser.parseColon() ||
969 parser.parseType(opFntype) ||
970 parser.parseOptionalLocationSpecifier(explicitLoc))
971 return failure();
972
973 // If location of the op is explicitly provided, then use it; Else use
974 // the parser's current location.
975 Location opLoc = explicitLoc.value_or(currLocation);
976
977 // Derive the SSA-values for op's operands.
978 if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
979 result.operands))
980 return failure();
981
982 // Add a region for op.
983 Region ®ion = *result.addRegion();
984
985 // Create a basic-block inside op's region.
986 Block &block = region.emplaceBlock();
987
988 // Create and insert an "inner-op" operation in the block.
989 // Just for testing purposes, we can assume that inner op is a binary op with
990 // result and operand types all same as the test-op's first operand.
991 Type innerOpType = opFntype.getInput(0);
992 Value lhs = block.addArgument(innerOpType, opLoc);
993 Value rhs = block.addArgument(innerOpType, opLoc);
994
995 OpBuilder builder(parser.getBuilder().getContext());
996 builder.setInsertionPointToStart(&block);
997
998 Operation *innerOp =
999 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
1000
1001 // Insert a return statement in the block returning the inner-op's result.
1002 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
1003
1004 // Populate the op operation-state with result-type and location.
1005 result.addTypes(opFntype.getResults());
1006 result.location = innerOp->getLoc();
1007
1008 return success();
1009 }
1010
print(OpAsmPrinter & p)1011 void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
1012 p << ' ';
1013 p.printOperands(getOperands());
1014
1015 Operation &innerOp = getRegion().front().front();
1016 // Assuming that region has a single non-terminator inner-op, if the inner-op
1017 // meets some criteria (which in this case is a simple one based on the name
1018 // of inner-op), then we can print the entire region in a succinct way.
1019 // Here we assume that the prototype of "special.op" can be trivially derived
1020 // while parsing it back.
1021 if (innerOp.getName().getStringRef().equals("special.op")) {
1022 p << " start special.op end";
1023 } else {
1024 p << " (";
1025 p.printRegion(getRegion());
1026 p << ")";
1027 }
1028
1029 p << " : ";
1030 p.printFunctionalType(*this);
1031 }
1032
1033 //===----------------------------------------------------------------------===//
1034 // Test PolyForOp - parse list of region arguments.
1035 //===----------------------------------------------------------------------===//
1036
parse(OpAsmParser & parser,OperationState & result)1037 ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
1038 SmallVector<OpAsmParser::Argument, 4> ivsInfo;
1039 // Parse list of region arguments without a delimiter.
1040 if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
1041 return failure();
1042
1043 // Parse the body region.
1044 Region *body = result.addRegion();
1045 for (auto &iv : ivsInfo)
1046 iv.type = parser.getBuilder().getIndexType();
1047 return parser.parseRegion(*body, ivsInfo);
1048 }
1049
print(OpAsmPrinter & p)1050 void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
1051
getAsmBlockArgumentNames(Region & region,OpAsmSetValueNameFn setNameFn)1052 void PolyForOp::getAsmBlockArgumentNames(Region ®ion,
1053 OpAsmSetValueNameFn setNameFn) {
1054 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
1055 if (!arrayAttr)
1056 return;
1057 auto args = getRegion().front().getArguments();
1058 auto e = std::min(arrayAttr.size(), args.size());
1059 for (unsigned i = 0; i < e; ++i) {
1060 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
1061 setNameFn(args[i], strAttr.getValue());
1062 }
1063 }
1064
1065 //===----------------------------------------------------------------------===//
1066 // Test removing op with inner ops.
1067 //===----------------------------------------------------------------------===//
1068
1069 namespace {
1070 struct TestRemoveOpWithInnerOps
1071 : public OpRewritePattern<TestOpWithRegionPattern> {
1072 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
1073
initialize__anonf77d94721411::TestRemoveOpWithInnerOps1074 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
1075
matchAndRewrite__anonf77d94721411::TestRemoveOpWithInnerOps1076 LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
1077 PatternRewriter &rewriter) const override {
1078 rewriter.eraseOp(op);
1079 return success();
1080 }
1081 };
1082 } // namespace
1083
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1084 void TestOpWithRegionPattern::getCanonicalizationPatterns(
1085 RewritePatternSet &results, MLIRContext *context) {
1086 results.add<TestRemoveOpWithInnerOps>(context);
1087 }
1088
fold(ArrayRef<Attribute> operands)1089 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
1090 return getOperand();
1091 }
1092
fold(ArrayRef<Attribute> operands)1093 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
1094 return getValue();
1095 }
1096
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1097 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
1098 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
1099 for (Value input : this->getOperands()) {
1100 results.push_back(input);
1101 }
1102 return success();
1103 }
1104
fold(ArrayRef<Attribute> operands)1105 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
1106 assert(operands.size() == 1);
1107 if (operands.front()) {
1108 (*this)->setAttr("attr", operands.front());
1109 return getResult();
1110 }
1111 return {};
1112 }
1113
fold(ArrayRef<Attribute> operands)1114 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
1115 return getOperand();
1116 }
1117
inferReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & inferredReturnTypes)1118 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
1119 MLIRContext *, Optional<Location> location, ValueRange operands,
1120 DictionaryAttr attributes, RegionRange regions,
1121 SmallVectorImpl<Type> &inferredReturnTypes) {
1122 if (operands[0].getType() != operands[1].getType()) {
1123 return emitOptionalError(location, "operand type mismatch ",
1124 operands[0].getType(), " vs ",
1125 operands[1].getType());
1126 }
1127 inferredReturnTypes.assign({operands[0].getType()});
1128 return success();
1129 }
1130
1131 // TODO: We should be able to only define either inferReturnType or
1132 // refineReturnType, currently only refineReturnType can be omitted.
inferReturnTypes(MLIRContext * context,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & returnTypes)1133 LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
1134 MLIRContext *context, Optional<Location> location, ValueRange operands,
1135 DictionaryAttr attributes, RegionRange regions,
1136 SmallVectorImpl<Type> &returnTypes) {
1137 returnTypes.clear();
1138 return OpWithRefineTypeInterfaceOp::refineReturnTypes(
1139 context, location, operands, attributes, regions, returnTypes);
1140 }
1141
refineReturnTypes(MLIRContext *,Optional<Location> location,ValueRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<Type> & returnTypes)1142 LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
1143 MLIRContext *, Optional<Location> location, ValueRange operands,
1144 DictionaryAttr attributes, RegionRange regions,
1145 SmallVectorImpl<Type> &returnTypes) {
1146 if (operands[0].getType() != operands[1].getType()) {
1147 return emitOptionalError(location, "operand type mismatch ",
1148 operands[0].getType(), " vs ",
1149 operands[1].getType());
1150 }
1151 // TODO: Add helper to make this more concise to write.
1152 if (returnTypes.empty())
1153 returnTypes.resize(1, nullptr);
1154 if (returnTypes[0] && returnTypes[0] != operands[0].getType())
1155 return emitOptionalError(location,
1156 "required first operand and result to match");
1157 returnTypes[0] = operands[0].getType();
1158 return success();
1159 }
1160
inferReturnTypeComponents(MLIRContext * context,Optional<Location> location,ValueShapeRange operands,DictionaryAttr attributes,RegionRange regions,SmallVectorImpl<ShapedTypeComponents> & inferredReturnShapes)1161 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
1162 MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
1163 DictionaryAttr attributes, RegionRange regions,
1164 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1165 // Create return type consisting of the last element of the first operand.
1166 auto operandType = operands.front().getType();
1167 auto sval = operandType.dyn_cast<ShapedType>();
1168 if (!sval) {
1169 return emitOptionalError(location, "only shaped type operands allowed");
1170 }
1171 int64_t dim =
1172 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
1173 auto type = IntegerType::get(context, 17);
1174 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
1175 return success();
1176 }
1177
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)1178 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
1179 OpBuilder &builder, ValueRange operands,
1180 llvm::SmallVectorImpl<Value> &shapes) {
1181 shapes = SmallVector<Value, 1>{
1182 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1183 return success();
1184 }
1185
reifyReturnTypeShapes(OpBuilder & builder,ValueRange operands,llvm::SmallVectorImpl<Value> & shapes)1186 LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
1187 OpBuilder &builder, ValueRange operands,
1188 llvm::SmallVectorImpl<Value> &shapes) {
1189 Location loc = getLoc();
1190 shapes.reserve(operands.size());
1191 for (Value operand : llvm::reverse(operands)) {
1192 auto rank = operand.getType().cast<RankedTensorType>().getRank();
1193 auto currShape = llvm::to_vector<4>(
1194 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1195 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1196 }));
1197 shapes.push_back(builder.create<tensor::FromElementsOp>(
1198 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1199 currShape));
1200 }
1201 return success();
1202 }
1203
reifyResultShapes(OpBuilder & builder,ReifiedRankedShapedTypeDims & shapes)1204 LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1205 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1206 Location loc = getLoc();
1207 shapes.reserve(getNumOperands());
1208 for (Value operand : llvm::reverse(getOperands())) {
1209 auto currShape = llvm::to_vector<4>(llvm::map_range(
1210 llvm::seq<int64_t>(
1211 0, operand.getType().cast<RankedTensorType>().getRank()),
1212 [&](int64_t dim) -> Value {
1213 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1214 }));
1215 shapes.emplace_back(std::move(currShape));
1216 }
1217 return success();
1218 }
1219
1220 //===----------------------------------------------------------------------===//
1221 // Test SideEffect interfaces
1222 //===----------------------------------------------------------------------===//
1223
1224 namespace {
1225 /// A test resource for side effects.
1226 struct TestResource : public SideEffects::Resource::Base<TestResource> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonf77d94721711::TestResource1227 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)
1228
1229 StringRef getName() final { return "<Test>"; }
1230 };
1231 } // namespace
1232
testSideEffectOpGetEffect(Operation * op,SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> & effects)1233 static void testSideEffectOpGetEffect(
1234 Operation *op,
1235 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1236 &effects) {
1237 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1238 if (!effectsAttr)
1239 return;
1240
1241 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1242 }
1243
getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> & effects)1244 void SideEffectOp::getEffects(
1245 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1246 // Check for an effects attribute on the op instance.
1247 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1248 if (!effectsAttr)
1249 return;
1250
1251 // If there is one, it is an array of dictionary attributes that hold
1252 // information on the effects of this operation.
1253 for (Attribute element : effectsAttr) {
1254 DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1255
1256 // Get the specific memory effect.
1257 MemoryEffects::Effect *effect =
1258 StringSwitch<MemoryEffects::Effect *>(
1259 effectElement.get("effect").cast<StringAttr>().getValue())
1260 .Case("allocate", MemoryEffects::Allocate::get())
1261 .Case("free", MemoryEffects::Free::get())
1262 .Case("read", MemoryEffects::Read::get())
1263 .Case("write", MemoryEffects::Write::get());
1264
1265 // Check for a non-default resource to use.
1266 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1267 if (effectElement.get("test_resource"))
1268 resource = TestResource::get();
1269
1270 // Check for a result to affect.
1271 if (effectElement.get("on_result"))
1272 effects.emplace_back(effect, getResult(), resource);
1273 else if (Attribute ref = effectElement.get("on_reference"))
1274 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1275 else
1276 effects.emplace_back(effect, resource);
1277 }
1278 }
1279
getEffects(SmallVectorImpl<TestEffects::EffectInstance> & effects)1280 void SideEffectOp::getEffects(
1281 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1282 testSideEffectOpGetEffect(getOperation(), effects);
1283 }
1284
1285 //===----------------------------------------------------------------------===//
1286 // StringAttrPrettyNameOp
1287 //===----------------------------------------------------------------------===//
1288
1289 // This op has fancy handling of its SSA result name.
parse(OpAsmParser & parser,OperationState & result)1290 ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1291 OperationState &result) {
1292 // Add the result types.
1293 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1294 result.addTypes(parser.getBuilder().getIntegerType(32));
1295
1296 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1297 return failure();
1298
1299 // If the attribute dictionary contains no 'names' attribute, infer it from
1300 // the SSA name (if specified).
1301 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1302 return attr.getName() == "names";
1303 });
1304
1305 // If there was no name specified, check to see if there was a useful name
1306 // specified in the asm file.
1307 if (hadNames || parser.getNumResults() == 0)
1308 return success();
1309
1310 SmallVector<StringRef, 4> names;
1311 auto *context = result.getContext();
1312
1313 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1314 auto resultName = parser.getResultName(i);
1315 StringRef nameStr;
1316 if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1317 nameStr = resultName.first;
1318
1319 names.push_back(nameStr);
1320 }
1321
1322 auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1323 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1324 return success();
1325 }
1326
print(OpAsmPrinter & p)1327 void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1328 // Note that we only need to print the "name" attribute if the asmprinter
1329 // result name disagrees with it. This can happen in strange cases, e.g.
1330 // when there are conflicts.
1331 bool namesDisagree = getNames().size() != getNumResults();
1332
1333 SmallString<32> resultNameStr;
1334 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1335 resultNameStr.clear();
1336 llvm::raw_svector_ostream tmpStream(resultNameStr);
1337 p.printOperand(getResult(i), tmpStream);
1338
1339 auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1340 if (!expectedName ||
1341 tmpStream.str().drop_front() != expectedName.getValue()) {
1342 namesDisagree = true;
1343 }
1344 }
1345
1346 if (namesDisagree)
1347 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1348 else
1349 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1350 }
1351
1352 // We set the SSA name in the asm syntax to the contents of the name
1353 // attribute.
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1354 void StringAttrPrettyNameOp::getAsmResultNames(
1355 function_ref<void(Value, StringRef)> setNameFn) {
1356
1357 auto value = getNames();
1358 for (size_t i = 0, e = value.size(); i != e; ++i)
1359 if (auto str = value[i].dyn_cast<StringAttr>())
1360 if (!str.getValue().empty())
1361 setNameFn(getResult(i), str.getValue());
1362 }
1363
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)1364 void CustomResultsNameOp::getAsmResultNames(
1365 function_ref<void(Value, StringRef)> setNameFn) {
1366 ArrayAttr value = getNames();
1367 for (size_t i = 0, e = value.size(); i != e; ++i)
1368 if (auto str = value[i].dyn_cast<StringAttr>())
1369 if (!str.getValue().empty())
1370 setNameFn(getResult(i), str.getValue());
1371 }
1372
1373 //===----------------------------------------------------------------------===//
1374 // ResultTypeWithTraitOp
1375 //===----------------------------------------------------------------------===//
1376
verify()1377 LogicalResult ResultTypeWithTraitOp::verify() {
1378 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1379 return success();
1380 return emitError("result type should have trait 'TestTypeTrait'");
1381 }
1382
1383 //===----------------------------------------------------------------------===//
1384 // AttrWithTraitOp
1385 //===----------------------------------------------------------------------===//
1386
verify()1387 LogicalResult AttrWithTraitOp::verify() {
1388 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1389 return success();
1390 return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1391 }
1392
1393 //===----------------------------------------------------------------------===//
1394 // RegionIfOp
1395 //===----------------------------------------------------------------------===//
1396
print(OpAsmPrinter & p)1397 void RegionIfOp::print(OpAsmPrinter &p) {
1398 p << " ";
1399 p.printOperands(getOperands());
1400 p << ": " << getOperandTypes();
1401 p.printArrowTypeList(getResultTypes());
1402 p << " then ";
1403 p.printRegion(getThenRegion(),
1404 /*printEntryBlockArgs=*/true,
1405 /*printBlockTerminators=*/true);
1406 p << " else ";
1407 p.printRegion(getElseRegion(),
1408 /*printEntryBlockArgs=*/true,
1409 /*printBlockTerminators=*/true);
1410 p << " join ";
1411 p.printRegion(getJoinRegion(),
1412 /*printEntryBlockArgs=*/true,
1413 /*printBlockTerminators=*/true);
1414 }
1415
parse(OpAsmParser & parser,OperationState & result)1416 ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1417 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1418 SmallVector<Type, 2> operandTypes;
1419
1420 result.regions.reserve(3);
1421 Region *thenRegion = result.addRegion();
1422 Region *elseRegion = result.addRegion();
1423 Region *joinRegion = result.addRegion();
1424
1425 // Parse operand, type and arrow type lists.
1426 if (parser.parseOperandList(operandInfos) ||
1427 parser.parseColonTypeList(operandTypes) ||
1428 parser.parseArrowTypeList(result.types))
1429 return failure();
1430
1431 // Parse all attached regions.
1432 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1433 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1434 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1435 return failure();
1436
1437 return parser.resolveOperands(operandInfos, operandTypes,
1438 parser.getCurrentLocation(), result.operands);
1439 }
1440
getSuccessorEntryOperands(Optional<unsigned> index)1441 OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) {
1442 assert(index && *index < 2 && "invalid region index");
1443 return getOperands();
1444 }
1445
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1446 void RegionIfOp::getSuccessorRegions(
1447 Optional<unsigned> index, ArrayRef<Attribute> operands,
1448 SmallVectorImpl<RegionSuccessor> ®ions) {
1449 // We always branch to the join region.
1450 if (index.has_value()) {
1451 if (index.value() < 2)
1452 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1453 else
1454 regions.push_back(RegionSuccessor(getResults()));
1455 return;
1456 }
1457
1458 // The then and else regions are the entry regions of this op.
1459 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1460 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1461 }
1462
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & invocationBounds)1463 void RegionIfOp::getRegionInvocationBounds(
1464 ArrayRef<Attribute> operands,
1465 SmallVectorImpl<InvocationBounds> &invocationBounds) {
1466 // Each region is invoked at most once.
1467 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1468 }
1469
1470 //===----------------------------------------------------------------------===//
1471 // AnyCondOp
1472 //===----------------------------------------------------------------------===//
1473
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)1474 void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1475 ArrayRef<Attribute> operands,
1476 SmallVectorImpl<RegionSuccessor> ®ions) {
1477 // The parent op branches into the only region, and the region branches back
1478 // to the parent op.
1479 if (!index)
1480 regions.emplace_back(&getRegion());
1481 else
1482 regions.emplace_back(getResults());
1483 }
1484
getRegionInvocationBounds(ArrayRef<Attribute> operands,SmallVectorImpl<InvocationBounds> & invocationBounds)1485 void AnyCondOp::getRegionInvocationBounds(
1486 ArrayRef<Attribute> operands,
1487 SmallVectorImpl<InvocationBounds> &invocationBounds) {
1488 invocationBounds.emplace_back(1, 1);
1489 }
1490
1491 //===----------------------------------------------------------------------===//
1492 // SingleNoTerminatorCustomAsmOp
1493 //===----------------------------------------------------------------------===//
1494
parse(OpAsmParser & parser,OperationState & state)1495 ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1496 OperationState &state) {
1497 Region *body = state.addRegion();
1498 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1499 return failure();
1500 return success();
1501 }
1502
print(OpAsmPrinter & printer)1503 void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1504 printer.printRegion(
1505 getRegion(), /*printEntryBlockArgs=*/false,
1506 // This op has a single block without terminators. But explicitly mark
1507 // as not printing block terminators for testing.
1508 /*printBlockTerminators=*/false);
1509 }
1510
1511 //===----------------------------------------------------------------------===//
1512 // TestVerifiersOp
1513 //===----------------------------------------------------------------------===//
1514
verify()1515 LogicalResult TestVerifiersOp::verify() {
1516 if (!getRegion().hasOneBlock())
1517 return emitOpError("`hasOneBlock` trait hasn't been verified");
1518
1519 Operation *definingOp = getInput().getDefiningOp();
1520 if (definingOp && failed(mlir::verify(definingOp)))
1521 return emitOpError("operand hasn't been verified");
1522
1523 emitRemark("success run of verifier");
1524
1525 return success();
1526 }
1527
verifyRegions()1528 LogicalResult TestVerifiersOp::verifyRegions() {
1529 if (!getRegion().hasOneBlock())
1530 return emitOpError("`hasOneBlock` trait hasn't been verified");
1531
1532 for (Block &block : getRegion())
1533 for (Operation &op : block)
1534 if (failed(mlir::verify(&op)))
1535 return emitOpError("nested op hasn't been verified");
1536
1537 emitRemark("success run of region verifier");
1538
1539 return success();
1540 }
1541
1542 //===----------------------------------------------------------------------===//
1543 // Test InferIntRangeInterface
1544 //===----------------------------------------------------------------------===//
1545
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRanges)1546 void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1547 SetIntRangeFn setResultRanges) {
1548 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
1549 }
1550
parse(OpAsmParser & parser,OperationState & result)1551 ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
1552 OperationState &result) {
1553 if (parser.parseOptionalAttrDict(result.attributes))
1554 return failure();
1555
1556 // Parse the input argument
1557 OpAsmParser::Argument argInfo;
1558 argInfo.type = parser.getBuilder().getIndexType();
1559 if (failed(parser.parseArgument(argInfo)))
1560 return failure();
1561
1562 // Parse the body region, and reuse the operand info as the argument info.
1563 Region *body = result.addRegion();
1564 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
1565 }
1566
print(OpAsmPrinter & p)1567 void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
1568 p.printOptionalAttrDict((*this)->getAttrs());
1569 p << ' ';
1570 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
1571 /*omitType=*/true);
1572 p << ' ';
1573 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
1574 }
1575
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRanges)1576 void TestWithBoundsRegionOp::inferResultRanges(
1577 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1578 Value arg = getRegion().getArgument(0);
1579 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
1580 }
1581
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRanges)1582 void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1583 SetIntRangeFn setResultRanges) {
1584 const ConstantIntRanges &range = argRanges[0];
1585 APInt one(range.umin().getBitWidth(), 1);
1586 setResultRanges(getResult(),
1587 {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
1588 range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
1589 }
1590
inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,SetIntRangeFn setResultRanges)1591 void TestReflectBoundsOp::inferResultRanges(
1592 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1593 const ConstantIntRanges &range = argRanges[0];
1594 MLIRContext *ctx = getContext();
1595 Builder b(ctx);
1596 setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
1597 setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
1598 setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
1599 setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
1600 setResultRanges(getResult(), range);
1601 }
1602
1603 #include "TestOpEnums.cpp.inc"
1604 #include "TestOpInterfaces.cpp.inc"
1605 #include "TestTypeInterfaces.cpp.inc"
1606
1607 #define GET_OP_CLASSES
1608 #include "TestOps.cpp.inc"
1609