1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/StringSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::test;
22 
23 void mlir::test::registerTestDialect(DialectRegistry &registry) {
24   registry.insert<TestDialect>();
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // TestDialect Interfaces
29 //===----------------------------------------------------------------------===//
30 
31 namespace {
32 
33 // Test support for interacting with the AsmPrinter.
34 struct TestOpAsmInterface : public OpAsmDialectInterface {
35   using OpAsmDialectInterface::OpAsmDialectInterface;
36 
37   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
38     StringAttr strAttr = attr.dyn_cast<StringAttr>();
39     if (!strAttr)
40       return failure();
41 
42     // Check the contents of the string attribute to see what the test alias
43     // should be named.
44     Optional<StringRef> aliasName =
45         StringSwitch<Optional<StringRef>>(strAttr.getValue())
46             .Case("alias_test:dot_in_name", StringRef("test.alias"))
47             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
48             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
49             .Case("alias_test:sanitize_conflict_a",
50                   StringRef("test_alias_conflict0"))
51             .Case("alias_test:sanitize_conflict_b",
52                   StringRef("test_alias_conflict0_"))
53             .Default(llvm::None);
54     if (!aliasName)
55       return failure();
56 
57     os << *aliasName;
58     return success();
59   }
60 
61   void getAsmResultNames(Operation *op,
62                          OpAsmSetValueNameFn setNameFn) const final {
63     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
64       setNameFn(asmOp, "result");
65   }
66 
67   void getAsmBlockArgumentNames(Block *block,
68                                 OpAsmSetValueNameFn setNameFn) const final {
69     auto op = block->getParentOp();
70     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
71     if (!arrayAttr)
72       return;
73     auto args = block->getArguments();
74     auto e = std::min(arrayAttr.size(), args.size());
75     for (unsigned i = 0; i < e; ++i) {
76       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
77         setNameFn(args[i], strAttr.getValue());
78     }
79   }
80 };
81 
82 struct TestDialectFoldInterface : public DialectFoldInterface {
83   using DialectFoldInterface::DialectFoldInterface;
84 
85   /// Registered hook to check if the given region, which is attached to an
86   /// operation that is *not* isolated from above, should be used when
87   /// materializing constants.
88   bool shouldMaterializeInto(Region *region) const final {
89     // If this is a one region operation, then insert into it.
90     return isa<OneRegionOp>(region->getParentOp());
91   }
92 };
93 
94 /// This class defines the interface for handling inlining with standard
95 /// operations.
96 struct TestInlinerInterface : public DialectInlinerInterface {
97   using DialectInlinerInterface::DialectInlinerInterface;
98 
99   //===--------------------------------------------------------------------===//
100   // Analysis Hooks
101   //===--------------------------------------------------------------------===//
102 
103   bool isLegalToInline(Operation *call, Operation *callable,
104                        bool wouldBeCloned) const final {
105     // Don't allow inlining calls that are marked `noinline`.
106     return !call->hasAttr("noinline");
107   }
108   bool isLegalToInline(Region *, Region *, bool,
109                        BlockAndValueMapping &) const final {
110     // Inlining into test dialect regions is legal.
111     return true;
112   }
113   bool isLegalToInline(Operation *, Region *, bool,
114                        BlockAndValueMapping &) const final {
115     return true;
116   }
117 
118   bool shouldAnalyzeRecursively(Operation *op) const final {
119     // Analyze recursively if this is not a functional region operation, it
120     // froms a separate functional scope.
121     return !isa<FunctionalRegionOp>(op);
122   }
123 
124   //===--------------------------------------------------------------------===//
125   // Transformation Hooks
126   //===--------------------------------------------------------------------===//
127 
128   /// Handle the given inlined terminator by replacing it with a new operation
129   /// as necessary.
130   void handleTerminator(Operation *op,
131                         ArrayRef<Value> valuesToRepl) const final {
132     // Only handle "test.return" here.
133     auto returnOp = dyn_cast<TestReturnOp>(op);
134     if (!returnOp)
135       return;
136 
137     // Replace the values directly with the return operands.
138     assert(returnOp.getNumOperands() == valuesToRepl.size());
139     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
140       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
141   }
142 
143   /// Attempt to materialize a conversion for a type mismatch between a call
144   /// from this dialect, and a callable region. This method should generate an
145   /// operation that takes 'input' as the only operand, and produces a single
146   /// result of 'resultType'. If a conversion can not be generated, nullptr
147   /// should be returned.
148   Operation *materializeCallConversion(OpBuilder &builder, Value input,
149                                        Type resultType,
150                                        Location conversionLoc) const final {
151     // Only allow conversion for i16/i32 types.
152     if (!(resultType.isSignlessInteger(16) ||
153           resultType.isSignlessInteger(32)) ||
154         !(input.getType().isSignlessInteger(16) ||
155           input.getType().isSignlessInteger(32)))
156       return nullptr;
157     return builder.create<TestCastOp>(conversionLoc, resultType, input);
158   }
159 };
160 } // end anonymous namespace
161 
162 //===----------------------------------------------------------------------===//
163 // TestDialect
164 //===----------------------------------------------------------------------===//
165 
166 void TestDialect::initialize() {
167   addOperations<
168 #define GET_OP_LIST
169 #include "TestOps.cpp.inc"
170       >();
171   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
172                 TestInlinerInterface>();
173   addTypes<TestType, TestRecursiveType,
174 #define GET_TYPEDEF_LIST
175 #include "TestTypeDefs.cpp.inc"
176            >();
177   allowUnknownOperations();
178 }
179 
180 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
181                                             Type type, Location loc) {
182   return builder.create<TestOpConstant>(loc, type, value);
183 }
184 
185 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
186                                                     NamedAttribute namedAttr) {
187   if (namedAttr.first == "test.invalid_attr")
188     return op->emitError() << "invalid to use 'test.invalid_attr'";
189   return success();
190 }
191 
192 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
193                                                     unsigned regionIndex,
194                                                     unsigned argIndex,
195                                                     NamedAttribute namedAttr) {
196   if (namedAttr.first == "test.invalid_attr")
197     return op->emitError() << "invalid to use 'test.invalid_attr'";
198   return success();
199 }
200 
201 LogicalResult
202 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
203                                          unsigned resultIndex,
204                                          NamedAttribute namedAttr) {
205   if (namedAttr.first == "test.invalid_attr")
206     return op->emitError() << "invalid to use 'test.invalid_attr'";
207   return success();
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // TestBranchOp
212 //===----------------------------------------------------------------------===//
213 
214 Optional<MutableOperandRange>
215 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
216   assert(index == 0 && "invalid successor index");
217   return targetOperandsMutable();
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // TestFoldToCallOp
222 //===----------------------------------------------------------------------===//
223 
224 namespace {
225 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
226   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
227 
228   LogicalResult matchAndRewrite(FoldToCallOp op,
229                                 PatternRewriter &rewriter) const override {
230     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
231                                         ValueRange());
232     return success();
233   }
234 };
235 } // end anonymous namespace
236 
237 void FoldToCallOp::getCanonicalizationPatterns(
238     OwningRewritePatternList &results, MLIRContext *context) {
239   results.insert<FoldToCallOpPattern>(context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Test Format* operations
244 //===----------------------------------------------------------------------===//
245 
246 //===----------------------------------------------------------------------===//
247 // Parsing
248 
249 static ParseResult parseCustomDirectiveOperands(
250     OpAsmParser &parser, OpAsmParser::OperandType &operand,
251     Optional<OpAsmParser::OperandType> &optOperand,
252     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
253   if (parser.parseOperand(operand))
254     return failure();
255   if (succeeded(parser.parseOptionalComma())) {
256     optOperand.emplace();
257     if (parser.parseOperand(*optOperand))
258       return failure();
259   }
260   if (parser.parseArrow() || parser.parseLParen() ||
261       parser.parseOperandList(varOperands) || parser.parseRParen())
262     return failure();
263   return success();
264 }
265 static ParseResult
266 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
267                             Type &optOperandType,
268                             SmallVectorImpl<Type> &varOperandTypes) {
269   if (parser.parseColon())
270     return failure();
271 
272   if (parser.parseType(operandType))
273     return failure();
274   if (succeeded(parser.parseOptionalComma())) {
275     if (parser.parseType(optOperandType))
276       return failure();
277   }
278   if (parser.parseArrow() || parser.parseLParen() ||
279       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
280     return failure();
281   return success();
282 }
283 static ParseResult
284 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
285                                  Type optOperandType,
286                                  const SmallVectorImpl<Type> &varOperandTypes) {
287   if (parser.parseKeyword("type_refs_capture"))
288     return failure();
289 
290   Type operandType2, optOperandType2;
291   SmallVector<Type, 1> varOperandTypes2;
292   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
293                                   varOperandTypes2))
294     return failure();
295 
296   if (operandType != operandType2 || optOperandType != optOperandType2 ||
297       varOperandTypes != varOperandTypes2)
298     return failure();
299 
300   return success();
301 }
302 static ParseResult parseCustomDirectiveOperandsAndTypes(
303     OpAsmParser &parser, OpAsmParser::OperandType &operand,
304     Optional<OpAsmParser::OperandType> &optOperand,
305     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
306     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
307   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
308       parseCustomDirectiveResults(parser, operandType, optOperandType,
309                                   varOperandTypes))
310     return failure();
311   return success();
312 }
313 static ParseResult parseCustomDirectiveRegions(
314     OpAsmParser &parser, Region &region,
315     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
316   if (parser.parseRegion(region))
317     return failure();
318   if (failed(parser.parseOptionalComma()))
319     return success();
320   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
321   if (parser.parseRegion(*varRegion))
322     return failure();
323   varRegions.emplace_back(std::move(varRegion));
324   return success();
325 }
326 static ParseResult
327 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
328                                SmallVectorImpl<Block *> &varSuccessors) {
329   if (parser.parseSuccessor(successor))
330     return failure();
331   if (failed(parser.parseOptionalComma()))
332     return success();
333   Block *varSuccessor;
334   if (parser.parseSuccessor(varSuccessor))
335     return failure();
336   varSuccessors.append(2, varSuccessor);
337   return success();
338 }
339 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
340                                                   IntegerAttr &attr,
341                                                   IntegerAttr &optAttr) {
342   if (parser.parseAttribute(attr))
343     return failure();
344   if (succeeded(parser.parseOptionalComma())) {
345     if (parser.parseAttribute(optAttr))
346       return failure();
347   }
348   return success();
349 }
350 
351 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
352                                                 NamedAttrList &attrs) {
353   return parser.parseOptionalAttrDict(attrs);
354 }
355 
356 //===----------------------------------------------------------------------===//
357 // Printing
358 
359 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
360                                          Value operand, Value optOperand,
361                                          OperandRange varOperands) {
362   printer << operand;
363   if (optOperand)
364     printer << ", " << optOperand;
365   printer << " -> (" << varOperands << ")";
366 }
367 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
368                                         Type operandType, Type optOperandType,
369                                         TypeRange varOperandTypes) {
370   printer << " : " << operandType;
371   if (optOperandType)
372     printer << ", " << optOperandType;
373   printer << " -> (" << varOperandTypes << ")";
374 }
375 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
376                                              Operation *op, Type operandType,
377                                              Type optOperandType,
378                                              TypeRange varOperandTypes) {
379   printer << " type_refs_capture ";
380   printCustomDirectiveResults(printer, op, operandType, optOperandType,
381                               varOperandTypes);
382 }
383 static void printCustomDirectiveOperandsAndTypes(
384     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
385     OperandRange varOperands, Type operandType, Type optOperandType,
386     TypeRange varOperandTypes) {
387   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
388   printCustomDirectiveResults(printer, op, operandType, optOperandType,
389                               varOperandTypes);
390 }
391 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
392                                         Region &region,
393                                         MutableArrayRef<Region> varRegions) {
394   printer.printRegion(region);
395   if (!varRegions.empty()) {
396     printer << ", ";
397     for (Region &region : varRegions)
398       printer.printRegion(region);
399   }
400 }
401 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
402                                            Block *successor,
403                                            SuccessorRange varSuccessors) {
404   printer << successor;
405   if (!varSuccessors.empty())
406     printer << ", " << varSuccessors.front();
407 }
408 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
409                                            Attribute attribute,
410                                            Attribute optAttribute) {
411   printer << attribute;
412   if (optAttribute)
413     printer << ", " << optAttribute;
414 }
415 
416 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
417                                          MutableDictionaryAttr attrs) {
418   printer.printOptionalAttrDict(attrs.getAttrs());
419 }
420 //===----------------------------------------------------------------------===//
421 // Test IsolatedRegionOp - parse passthrough region arguments.
422 //===----------------------------------------------------------------------===//
423 
424 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
425                                          OperationState &result) {
426   OpAsmParser::OperandType argInfo;
427   Type argType = parser.getBuilder().getIndexType();
428 
429   // Parse the input operand.
430   if (parser.parseOperand(argInfo) ||
431       parser.resolveOperand(argInfo, argType, result.operands))
432     return failure();
433 
434   // Parse the body region, and reuse the operand info as the argument info.
435   Region *body = result.addRegion();
436   return parser.parseRegion(*body, argInfo, argType,
437                             /*enableNameShadowing=*/true);
438 }
439 
440 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
441   p << "test.isolated_region ";
442   p.printOperand(op.getOperand());
443   p.shadowRegionArgs(op.region(), op.getOperand());
444   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
445 }
446 
447 //===----------------------------------------------------------------------===//
448 // Test SSACFGRegionOp
449 //===----------------------------------------------------------------------===//
450 
451 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
452   return RegionKind::SSACFG;
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // Test GraphRegionOp
457 //===----------------------------------------------------------------------===//
458 
459 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
460                                       OperationState &result) {
461   // Parse the body region, and reuse the operand info as the argument info.
462   Region *body = result.addRegion();
463   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
464 }
465 
466 static void print(OpAsmPrinter &p, GraphRegionOp op) {
467   p << "test.graph_region ";
468   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
469 }
470 
471 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
472   return RegionKind::Graph;
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // Test AffineScopeOp
477 //===----------------------------------------------------------------------===//
478 
479 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
480                                       OperationState &result) {
481   // Parse the body region, and reuse the operand info as the argument info.
482   Region *body = result.addRegion();
483   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
484 }
485 
486 static void print(OpAsmPrinter &p, AffineScopeOp op) {
487   p << "test.affine_scope ";
488   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // Test parser.
493 //===----------------------------------------------------------------------===//
494 
495 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
496                                               OperationState &result) {
497   if (parser.parseOptionalColon())
498     return success();
499   uint64_t numResults;
500   if (parser.parseInteger(numResults))
501     return failure();
502 
503   IndexType type = parser.getBuilder().getIndexType();
504   for (unsigned i = 0; i < numResults; ++i)
505     result.addTypes(type);
506   return success();
507 }
508 
509 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
510   p << ParseIntegerLiteralOp::getOperationName();
511   if (unsigned numResults = op->getNumResults())
512     p << " : " << numResults;
513 }
514 
515 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
516                                               OperationState &result) {
517   StringRef keyword;
518   if (parser.parseKeyword(&keyword))
519     return failure();
520   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
521   return success();
522 }
523 
524 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
525   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
530 
531 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
532                                          OperationState &result) {
533   if (parser.parseKeyword("wraps"))
534     return failure();
535 
536   // Parse the wrapped op in a region
537   Region &body = *result.addRegion();
538   body.push_back(new Block);
539   Block &block = body.back();
540   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
541   if (!wrapped_op)
542     return failure();
543 
544   // Create a return terminator in the inner region, pass as operand to the
545   // terminator the returned values from the wrapped operation.
546   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
547   OpBuilder builder(parser.getBuilder().getContext());
548   builder.setInsertionPointToEnd(&block);
549   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
550 
551   // Get the results type for the wrapping op from the terminator operands.
552   Operation &return_op = body.back().back();
553   result.types.append(return_op.operand_type_begin(),
554                       return_op.operand_type_end());
555 
556   // Use the location of the wrapped op for the "test.wrapping_region" op.
557   result.location = wrapped_op->getLoc();
558 
559   return success();
560 }
561 
562 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
563   p << op.getOperationName() << " wraps ";
564   p.printGenericOp(&op.region().front().front());
565 }
566 
567 //===----------------------------------------------------------------------===//
568 // Test PolyForOp - parse list of region arguments.
569 //===----------------------------------------------------------------------===//
570 
571 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
572   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
573   // Parse list of region arguments without a delimiter.
574   if (parser.parseRegionArgumentList(ivsInfo))
575     return failure();
576 
577   // Parse the body region.
578   Region *body = result.addRegion();
579   auto &builder = parser.getBuilder();
580   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
581   return parser.parseRegion(*body, ivsInfo, argTypes);
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // Test removing op with inner ops.
586 //===----------------------------------------------------------------------===//
587 
588 namespace {
589 struct TestRemoveOpWithInnerOps
590     : public OpRewritePattern<TestOpWithRegionPattern> {
591   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
592 
593   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
594                                 PatternRewriter &rewriter) const override {
595     rewriter.eraseOp(op);
596     return success();
597   }
598 };
599 } // end anonymous namespace
600 
601 void TestOpWithRegionPattern::getCanonicalizationPatterns(
602     OwningRewritePatternList &results, MLIRContext *context) {
603   results.insert<TestRemoveOpWithInnerOps>(context);
604 }
605 
606 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
607   return operand();
608 }
609 
610 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
611   return getValue();
612 }
613 
614 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
615     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
616   for (Value input : this->operands()) {
617     results.push_back(input);
618   }
619   return success();
620 }
621 
622 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
623   assert(operands.size() == 1);
624   if (operands.front()) {
625     (*this)->setAttr("attr", operands.front());
626     return getResult();
627   }
628   return {};
629 }
630 
631 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
632     MLIRContext *, Optional<Location> location, ValueRange operands,
633     DictionaryAttr attributes, RegionRange regions,
634     SmallVectorImpl<Type> &inferredReturnTypes) {
635   if (operands[0].getType() != operands[1].getType()) {
636     return emitOptionalError(location, "operand type mismatch ",
637                              operands[0].getType(), " vs ",
638                              operands[1].getType());
639   }
640   inferredReturnTypes.assign({operands[0].getType()});
641   return success();
642 }
643 
644 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
645     MLIRContext *context, Optional<Location> location, ValueRange operands,
646     DictionaryAttr attributes, RegionRange regions,
647     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
648   // Create return type consisting of the last element of the first operand.
649   auto operandType = *operands.getTypes().begin();
650   auto sval = operandType.dyn_cast<ShapedType>();
651   if (!sval) {
652     return emitOptionalError(location, "only shaped type operands allowed");
653   }
654   int64_t dim =
655       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
656   auto type = IntegerType::get(17, context);
657   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
658   return success();
659 }
660 
661 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
662     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
663   shapes = SmallVector<Value, 1>{
664       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
665   return success();
666 }
667 
668 //===----------------------------------------------------------------------===//
669 // Test SideEffect interfaces
670 //===----------------------------------------------------------------------===//
671 
672 namespace {
673 /// A test resource for side effects.
674 struct TestResource : public SideEffects::Resource::Base<TestResource> {
675   StringRef getName() final { return "<Test>"; }
676 };
677 } // end anonymous namespace
678 
679 void SideEffectOp::getEffects(
680     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
681   // Check for an effects attribute on the op instance.
682   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
683   if (!effectsAttr)
684     return;
685 
686   // If there is one, it is an array of dictionary attributes that hold
687   // information on the effects of this operation.
688   for (Attribute element : effectsAttr) {
689     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
690 
691     // Get the specific memory effect.
692     MemoryEffects::Effect *effect =
693         StringSwitch<MemoryEffects::Effect *>(
694             effectElement.get("effect").cast<StringAttr>().getValue())
695             .Case("allocate", MemoryEffects::Allocate::get())
696             .Case("free", MemoryEffects::Free::get())
697             .Case("read", MemoryEffects::Read::get())
698             .Case("write", MemoryEffects::Write::get());
699 
700     // Check for a non-default resource to use.
701     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
702     if (effectElement.get("test_resource"))
703       resource = TestResource::get();
704 
705     // Check for a result to affect.
706     if (effectElement.get("on_result"))
707       effects.emplace_back(effect, getResult(), resource);
708     else if (Attribute ref = effectElement.get("on_reference"))
709       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
710     else
711       effects.emplace_back(effect, resource);
712   }
713 }
714 
715 void SideEffectOp::getEffects(
716     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
717   auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
718   if (!effectsAttr)
719     return;
720 
721   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
722 }
723 
724 //===----------------------------------------------------------------------===//
725 // StringAttrPrettyNameOp
726 //===----------------------------------------------------------------------===//
727 
728 // This op has fancy handling of its SSA result name.
729 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
730                                                OperationState &result) {
731   // Add the result types.
732   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
733     result.addTypes(parser.getBuilder().getIntegerType(32));
734 
735   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
736     return failure();
737 
738   // If the attribute dictionary contains no 'names' attribute, infer it from
739   // the SSA name (if specified).
740   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
741     return attr.first == "names";
742   });
743 
744   // If there was no name specified, check to see if there was a useful name
745   // specified in the asm file.
746   if (hadNames || parser.getNumResults() == 0)
747     return success();
748 
749   SmallVector<StringRef, 4> names;
750   auto *context = result.getContext();
751 
752   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
753     auto resultName = parser.getResultName(i);
754     StringRef nameStr;
755     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
756       nameStr = resultName.first;
757 
758     names.push_back(nameStr);
759   }
760 
761   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
762   result.attributes.push_back({Identifier::get("names", context), namesAttr});
763   return success();
764 }
765 
766 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
767   p << "test.string_attr_pretty_name";
768 
769   // Note that we only need to print the "name" attribute if the asmprinter
770   // result name disagrees with it.  This can happen in strange cases, e.g.
771   // when there are conflicts.
772   bool namesDisagree = op.names().size() != op.getNumResults();
773 
774   SmallString<32> resultNameStr;
775   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
776     resultNameStr.clear();
777     llvm::raw_svector_ostream tmpStream(resultNameStr);
778     p.printOperand(op.getResult(i), tmpStream);
779 
780     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
781     if (!expectedName ||
782         tmpStream.str().drop_front() != expectedName.getValue()) {
783       namesDisagree = true;
784     }
785   }
786 
787   if (namesDisagree)
788     p.printOptionalAttrDictWithKeyword(op.getAttrs());
789   else
790     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
791 }
792 
793 // We set the SSA name in the asm syntax to the contents of the name
794 // attribute.
795 void StringAttrPrettyNameOp::getAsmResultNames(
796     function_ref<void(Value, StringRef)> setNameFn) {
797 
798   auto value = names();
799   for (size_t i = 0, e = value.size(); i != e; ++i)
800     if (auto str = value[i].dyn_cast<StringAttr>())
801       if (!str.getValue().empty())
802         setNameFn(getResult(i), str.getValue());
803 }
804 
805 //===----------------------------------------------------------------------===//
806 // RegionIfOp
807 //===----------------------------------------------------------------------===//
808 
809 static void print(OpAsmPrinter &p, RegionIfOp op) {
810   p << RegionIfOp::getOperationName() << " ";
811   p.printOperands(op.getOperands());
812   p << ": " << op.getOperandTypes();
813   p.printArrowTypeList(op.getResultTypes());
814   p << " then";
815   p.printRegion(op.thenRegion(),
816                 /*printEntryBlockArgs=*/true,
817                 /*printBlockTerminators=*/true);
818   p << " else";
819   p.printRegion(op.elseRegion(),
820                 /*printEntryBlockArgs=*/true,
821                 /*printBlockTerminators=*/true);
822   p << " join";
823   p.printRegion(op.joinRegion(),
824                 /*printEntryBlockArgs=*/true,
825                 /*printBlockTerminators=*/true);
826 }
827 
828 static ParseResult parseRegionIfOp(OpAsmParser &parser,
829                                    OperationState &result) {
830   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
831   SmallVector<Type, 2> operandTypes;
832 
833   result.regions.reserve(3);
834   Region *thenRegion = result.addRegion();
835   Region *elseRegion = result.addRegion();
836   Region *joinRegion = result.addRegion();
837 
838   // Parse operand, type and arrow type lists.
839   if (parser.parseOperandList(operandInfos) ||
840       parser.parseColonTypeList(operandTypes) ||
841       parser.parseArrowTypeList(result.types))
842     return failure();
843 
844   // Parse all attached regions.
845   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
846       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
847       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
848     return failure();
849 
850   return parser.resolveOperands(operandInfos, operandTypes,
851                                 parser.getCurrentLocation(), result.operands);
852 }
853 
854 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
855   assert(index < 2 && "invalid region index");
856   return getOperands();
857 }
858 
859 void RegionIfOp::getSuccessorRegions(
860     Optional<unsigned> index, ArrayRef<Attribute> operands,
861     SmallVectorImpl<RegionSuccessor> &regions) {
862   // We always branch to the join region.
863   if (index.hasValue()) {
864     if (index.getValue() < 2)
865       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
866     else
867       regions.push_back(RegionSuccessor(getResults()));
868     return;
869   }
870 
871   // The then and else regions are the entry regions of this op.
872   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
873   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
874 }
875 
876 #include "TestOpEnums.cpp.inc"
877 #include "TestOpInterfaces.cpp.inc"
878 #include "TestOpStructs.cpp.inc"
879 #include "TestTypeInterfaces.cpp.inc"
880 
881 #define GET_OP_CLASSES
882 #include "TestOps.cpp.inc"
883