1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Function.h"
14 #include "mlir/IR/Module.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringSwitch.h"
21 
22 using namespace mlir;
23 
24 void mlir::registerTestDialect(DialectRegistry &registry) {
25   registry.insert<TestDialect>();
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // TestDialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 
34 // Test support for interacting with the AsmPrinter.
35 struct TestOpAsmInterface : public OpAsmDialectInterface {
36   using OpAsmDialectInterface::OpAsmDialectInterface;
37 
38   void getAsmResultNames(Operation *op,
39                          OpAsmSetValueNameFn setNameFn) const final {
40     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
41       setNameFn(asmOp, "result");
42   }
43 
44   void getAsmBlockArgumentNames(Block *block,
45                                 OpAsmSetValueNameFn setNameFn) const final {
46     auto op = block->getParentOp();
47     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
48     if (!arrayAttr)
49       return;
50     auto args = block->getArguments();
51     auto e = std::min(arrayAttr.size(), args.size());
52     for (unsigned i = 0; i < e; ++i) {
53       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
54         setNameFn(args[i], strAttr.getValue());
55     }
56   }
57 };
58 
59 struct TestDialectFoldInterface : public DialectFoldInterface {
60   using DialectFoldInterface::DialectFoldInterface;
61 
62   /// Registered hook to check if the given region, which is attached to an
63   /// operation that is *not* isolated from above, should be used when
64   /// materializing constants.
65   bool shouldMaterializeInto(Region *region) const final {
66     // If this is a one region operation, then insert into it.
67     return isa<OneRegionOp>(region->getParentOp());
68   }
69 };
70 
71 /// This class defines the interface for handling inlining with standard
72 /// operations.
73 struct TestInlinerInterface : public DialectInlinerInterface {
74   using DialectInlinerInterface::DialectInlinerInterface;
75 
76   //===--------------------------------------------------------------------===//
77   // Analysis Hooks
78   //===--------------------------------------------------------------------===//
79 
80   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
81     // Inlining into test dialect regions is legal.
82     return true;
83   }
84   bool isLegalToInline(Operation *, Region *,
85                        BlockAndValueMapping &) const final {
86     return true;
87   }
88 
89   bool shouldAnalyzeRecursively(Operation *op) const final {
90     // Analyze recursively if this is not a functional region operation, it
91     // froms a separate functional scope.
92     return !isa<FunctionalRegionOp>(op);
93   }
94 
95   //===--------------------------------------------------------------------===//
96   // Transformation Hooks
97   //===--------------------------------------------------------------------===//
98 
99   /// Handle the given inlined terminator by replacing it with a new operation
100   /// as necessary.
101   void handleTerminator(Operation *op,
102                         ArrayRef<Value> valuesToRepl) const final {
103     // Only handle "test.return" here.
104     auto returnOp = dyn_cast<TestReturnOp>(op);
105     if (!returnOp)
106       return;
107 
108     // Replace the values directly with the return operands.
109     assert(returnOp.getNumOperands() == valuesToRepl.size());
110     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
111       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
112   }
113 
114   /// Attempt to materialize a conversion for a type mismatch between a call
115   /// from this dialect, and a callable region. This method should generate an
116   /// operation that takes 'input' as the only operand, and produces a single
117   /// result of 'resultType'. If a conversion can not be generated, nullptr
118   /// should be returned.
119   Operation *materializeCallConversion(OpBuilder &builder, Value input,
120                                        Type resultType,
121                                        Location conversionLoc) const final {
122     // Only allow conversion for i16/i32 types.
123     if (!(resultType.isSignlessInteger(16) ||
124           resultType.isSignlessInteger(32)) ||
125         !(input.getType().isSignlessInteger(16) ||
126           input.getType().isSignlessInteger(32)))
127       return nullptr;
128     return builder.create<TestCastOp>(conversionLoc, resultType, input);
129   }
130 };
131 } // end anonymous namespace
132 
133 //===----------------------------------------------------------------------===//
134 // TestDialect
135 //===----------------------------------------------------------------------===//
136 
/// Registers the test dialect's operations, dialect interfaces and types.
/// The operation list and the typedef list are spliced in from the generated
/// .cpp.inc files selected by the GET_OP_LIST / GET_TYPEDEF_LIST macros.
void TestDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  // TestType and TestRecursiveType are listed explicitly; the remaining types
  // come from the generated typedef list.
  addTypes<TestType, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  // Accept operations of this dialect that are not registered above (see
  // Dialect::allowUnknownOperations).
  allowUnknownOperations();
}
150 
151 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
152                           llvm::SetVector<Type> &stack) {
153   StringRef typeTag;
154   if (failed(parser.parseKeyword(&typeTag)))
155     return Type();
156 
157   auto genType = generatedTypeParser(ctxt, parser, typeTag);
158   if (genType != Type())
159     return genType;
160 
161   if (typeTag == "test_type")
162     return TestType::get(parser.getBuilder().getContext());
163 
164   if (typeTag != "test_rec")
165     return Type();
166 
167   StringRef name;
168   if (parser.parseLess() || parser.parseKeyword(&name))
169     return Type();
170   auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
171 
172   // If this type already has been parsed above in the stack, expect just the
173   // name.
174   if (stack.contains(rec)) {
175     if (failed(parser.parseGreater()))
176       return Type();
177     return rec;
178   }
179 
180   // Otherwise, parse the body and update the type.
181   if (failed(parser.parseComma()))
182     return Type();
183   stack.insert(rec);
184   Type subtype = parseTestType(ctxt, parser, stack);
185   stack.pop_back();
186   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
187     return Type();
188 
189   return rec;
190 }
191 
192 Type TestDialect::parseType(DialectAsmParser &parser) const {
193   llvm::SetVector<Type> stack;
194   return parseTestType(getContext(), parser, stack);
195 }
196 
197 static void printTestType(Type type, DialectAsmPrinter &printer,
198                           llvm::SetVector<Type> &stack) {
199   if (succeeded(generatedTypePrinter(type, printer)))
200     return;
201   if (type.isa<TestType>()) {
202     printer << "test_type";
203     return;
204   }
205 
206   auto rec = type.cast<TestRecursiveType>();
207   printer << "test_rec<" << rec.getName();
208   if (!stack.contains(rec)) {
209     printer << ", ";
210     stack.insert(rec);
211     printTestType(rec.getBody(), printer, stack);
212     stack.pop_back();
213   }
214   printer << ">";
215 }
216 
217 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
218   llvm::SetVector<Type> stack;
219   printTestType(type, printer, stack);
220 }
221 
222 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
223                                                     NamedAttribute namedAttr) {
224   if (namedAttr.first == "test.invalid_attr")
225     return op->emitError() << "invalid to use 'test.invalid_attr'";
226   return success();
227 }
228 
229 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
230                                                     unsigned regionIndex,
231                                                     unsigned argIndex,
232                                                     NamedAttribute namedAttr) {
233   if (namedAttr.first == "test.invalid_attr")
234     return op->emitError() << "invalid to use 'test.invalid_attr'";
235   return success();
236 }
237 
238 LogicalResult
239 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
240                                          unsigned resultIndex,
241                                          NamedAttribute namedAttr) {
242   if (namedAttr.first == "test.invalid_attr")
243     return op->emitError() << "invalid to use 'test.invalid_attr'";
244   return success();
245 }
246 
247 //===----------------------------------------------------------------------===//
248 // TestBranchOp
249 //===----------------------------------------------------------------------===//
250 
251 Optional<MutableOperandRange>
252 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
253   assert(index == 0 && "invalid successor index");
254   return targetOperandsMutable();
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // TestFoldToCallOp
259 //===----------------------------------------------------------------------===//
260 
261 namespace {
262 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
263   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
264 
265   LogicalResult matchAndRewrite(FoldToCallOp op,
266                                 PatternRewriter &rewriter) const override {
267     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
268                                         ValueRange());
269     return success();
270   }
271 };
272 } // end anonymous namespace
273 
274 void FoldToCallOp::getCanonicalizationPatterns(
275     OwningRewritePatternList &results, MLIRContext *context) {
276   results.insert<FoldToCallOpPattern>(context);
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // Test Format* operations
281 //===----------------------------------------------------------------------===//
282 
283 //===----------------------------------------------------------------------===//
284 // Parsing
285 
286 static ParseResult parseCustomDirectiveOperands(
287     OpAsmParser &parser, OpAsmParser::OperandType &operand,
288     Optional<OpAsmParser::OperandType> &optOperand,
289     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
290   if (parser.parseOperand(operand))
291     return failure();
292   if (succeeded(parser.parseOptionalComma())) {
293     optOperand.emplace();
294     if (parser.parseOperand(*optOperand))
295       return failure();
296   }
297   if (parser.parseArrow() || parser.parseLParen() ||
298       parser.parseOperandList(varOperands) || parser.parseRParen())
299     return failure();
300   return success();
301 }
302 static ParseResult
303 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
304                             Type &optOperandType,
305                             SmallVectorImpl<Type> &varOperandTypes) {
306   if (parser.parseColon())
307     return failure();
308 
309   if (parser.parseType(operandType))
310     return failure();
311   if (succeeded(parser.parseOptionalComma())) {
312     if (parser.parseType(optOperandType))
313       return failure();
314   }
315   if (parser.parseArrow() || parser.parseLParen() ||
316       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
317     return failure();
318   return success();
319 }
320 static ParseResult
321 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
322                                  Type optOperandType,
323                                  const SmallVectorImpl<Type> &varOperandTypes) {
324   if (parser.parseKeyword("type_refs_capture"))
325     return failure();
326 
327   Type operandType2, optOperandType2;
328   SmallVector<Type, 1> varOperandTypes2;
329   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
330                                   varOperandTypes2))
331     return failure();
332 
333   if (operandType != operandType2 || optOperandType != optOperandType2 ||
334       varOperandTypes != varOperandTypes2)
335     return failure();
336 
337   return success();
338 }
339 static ParseResult parseCustomDirectiveOperandsAndTypes(
340     OpAsmParser &parser, OpAsmParser::OperandType &operand,
341     Optional<OpAsmParser::OperandType> &optOperand,
342     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
343     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
344   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
345       parseCustomDirectiveResults(parser, operandType, optOperandType,
346                                   varOperandTypes))
347     return failure();
348   return success();
349 }
350 static ParseResult parseCustomDirectiveRegions(
351     OpAsmParser &parser, Region &region,
352     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
353   if (parser.parseRegion(region))
354     return failure();
355   if (failed(parser.parseOptionalComma()))
356     return success();
357   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
358   if (parser.parseRegion(*varRegion))
359     return failure();
360   varRegions.emplace_back(std::move(varRegion));
361   return success();
362 }
363 static ParseResult
364 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
365                                SmallVectorImpl<Block *> &varSuccessors) {
366   if (parser.parseSuccessor(successor))
367     return failure();
368   if (failed(parser.parseOptionalComma()))
369     return success();
370   Block *varSuccessor;
371   if (parser.parseSuccessor(varSuccessor))
372     return failure();
373   varSuccessors.append(2, varSuccessor);
374   return success();
375 }
376 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
377                                                   IntegerAttr &attr,
378                                                   IntegerAttr &optAttr) {
379   if (parser.parseAttribute(attr))
380     return failure();
381   if (succeeded(parser.parseOptionalComma())) {
382     if (parser.parseAttribute(optAttr))
383       return failure();
384   }
385   return success();
386 }
387 
388 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
389                                                 NamedAttrList &attrs) {
390   return parser.parseOptionalAttrDict(attrs);
391 }
392 
393 //===----------------------------------------------------------------------===//
394 // Printing
395 
396 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
397                                          Value operand, Value optOperand,
398                                          OperandRange varOperands) {
399   printer << operand;
400   if (optOperand)
401     printer << ", " << optOperand;
402   printer << " -> (" << varOperands << ")";
403 }
404 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
405                                         Type operandType, Type optOperandType,
406                                         TypeRange varOperandTypes) {
407   printer << " : " << operandType;
408   if (optOperandType)
409     printer << ", " << optOperandType;
410   printer << " -> (" << varOperandTypes << ")";
411 }
412 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
413                                              Operation *op, Type operandType,
414                                              Type optOperandType,
415                                              TypeRange varOperandTypes) {
416   printer << " type_refs_capture ";
417   printCustomDirectiveResults(printer, op, operandType, optOperandType,
418                               varOperandTypes);
419 }
420 static void printCustomDirectiveOperandsAndTypes(
421     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
422     OperandRange varOperands, Type operandType, Type optOperandType,
423     TypeRange varOperandTypes) {
424   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
425   printCustomDirectiveResults(printer, op, operandType, optOperandType,
426                               varOperandTypes);
427 }
428 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
429                                         Region &region,
430                                         MutableArrayRef<Region> varRegions) {
431   printer.printRegion(region);
432   if (!varRegions.empty()) {
433     printer << ", ";
434     for (Region &region : varRegions)
435       printer.printRegion(region);
436   }
437 }
438 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
439                                            Block *successor,
440                                            SuccessorRange varSuccessors) {
441   printer << successor;
442   if (!varSuccessors.empty())
443     printer << ", " << varSuccessors.front();
444 }
445 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
446                                            Attribute attribute,
447                                            Attribute optAttribute) {
448   printer << attribute;
449   if (optAttribute)
450     printer << ", " << optAttribute;
451 }
452 
453 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
454                                          MutableDictionaryAttr attrs) {
455   printer.printOptionalAttrDict(attrs.getAttrs());
456 }
457 //===----------------------------------------------------------------------===//
458 // Test IsolatedRegionOp - parse passthrough region arguments.
459 //===----------------------------------------------------------------------===//
460 
461 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
462                                          OperationState &result) {
463   OpAsmParser::OperandType argInfo;
464   Type argType = parser.getBuilder().getIndexType();
465 
466   // Parse the input operand.
467   if (parser.parseOperand(argInfo) ||
468       parser.resolveOperand(argInfo, argType, result.operands))
469     return failure();
470 
471   // Parse the body region, and reuse the operand info as the argument info.
472   Region *body = result.addRegion();
473   return parser.parseRegion(*body, argInfo, argType,
474                             /*enableNameShadowing=*/true);
475 }
476 
477 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
478   p << "test.isolated_region ";
479   p.printOperand(op.getOperand());
480   p.shadowRegionArgs(op.region(), op.getOperand());
481   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
482 }
483 
484 //===----------------------------------------------------------------------===//
485 // Test SSACFGRegionOp
486 //===----------------------------------------------------------------------===//
487 
488 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
489   return RegionKind::SSACFG;
490 }
491 
492 //===----------------------------------------------------------------------===//
493 // Test GraphRegionOp
494 //===----------------------------------------------------------------------===//
495 
496 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
497                                       OperationState &result) {
498   // Parse the body region, and reuse the operand info as the argument info.
499   Region *body = result.addRegion();
500   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
501 }
502 
503 static void print(OpAsmPrinter &p, GraphRegionOp op) {
504   p << "test.graph_region ";
505   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
506 }
507 
508 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
509   return RegionKind::Graph;
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // Test AffineScopeOp
514 //===----------------------------------------------------------------------===//
515 
516 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
517                                       OperationState &result) {
518   // Parse the body region, and reuse the operand info as the argument info.
519   Region *body = result.addRegion();
520   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
521 }
522 
523 static void print(OpAsmPrinter &p, AffineScopeOp op) {
524   p << "test.affine_scope ";
525   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // Test parser.
530 //===----------------------------------------------------------------------===//
531 
532 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
533                                          OperationState &result) {
534   StringRef keyword;
535   if (parser.parseKeyword(&keyword))
536     return failure();
537   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
538   return success();
539 }
540 
541 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
542   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
543 }
544 
545 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
547 
548 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
549                                          OperationState &result) {
550   if (parser.parseKeyword("wraps"))
551     return failure();
552 
553   // Parse the wrapped op in a region
554   Region &body = *result.addRegion();
555   body.push_back(new Block);
556   Block &block = body.back();
557   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
558   if (!wrapped_op)
559     return failure();
560 
561   // Create a return terminator in the inner region, pass as operand to the
562   // terminator the returned values from the wrapped operation.
563   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
564   OpBuilder builder(parser.getBuilder().getContext());
565   builder.setInsertionPointToEnd(&block);
566   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
567 
568   // Get the results type for the wrapping op from the terminator operands.
569   Operation &return_op = body.back().back();
570   result.types.append(return_op.operand_type_begin(),
571                       return_op.operand_type_end());
572 
573   // Use the location of the wrapped op for the "test.wrapping_region" op.
574   result.location = wrapped_op->getLoc();
575 
576   return success();
577 }
578 
579 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
580   p << op.getOperationName() << " wraps ";
581   p.printGenericOp(&op.region().front().front());
582 }
583 
584 //===----------------------------------------------------------------------===//
585 // Test PolyForOp - parse list of region arguments.
586 //===----------------------------------------------------------------------===//
587 
588 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
589   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
590   // Parse list of region arguments without a delimiter.
591   if (parser.parseRegionArgumentList(ivsInfo))
592     return failure();
593 
594   // Parse the body region.
595   Region *body = result.addRegion();
596   auto &builder = parser.getBuilder();
597   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
598   return parser.parseRegion(*body, ivsInfo, argTypes);
599 }
600 
601 //===----------------------------------------------------------------------===//
602 // Test removing op with inner ops.
603 //===----------------------------------------------------------------------===//
604 
605 namespace {
606 struct TestRemoveOpWithInnerOps
607     : public OpRewritePattern<TestOpWithRegionPattern> {
608   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
609 
610   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
611                                 PatternRewriter &rewriter) const override {
612     rewriter.eraseOp(op);
613     return success();
614   }
615 };
616 } // end anonymous namespace
617 
618 void TestOpWithRegionPattern::getCanonicalizationPatterns(
619     OwningRewritePatternList &results, MLIRContext *context) {
620   results.insert<TestRemoveOpWithInnerOps>(context);
621 }
622 
623 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
624   return operand();
625 }
626 
627 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
628   return getValue();
629 }
630 
631 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
632     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
633   for (Value input : this->operands()) {
634     results.push_back(input);
635   }
636   return success();
637 }
638 
639 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
640   assert(operands.size() == 1);
641   if (operands.front()) {
642     setAttr("attr", operands.front());
643     return getResult();
644   }
645   return {};
646 }
647 
648 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
649     MLIRContext *, Optional<Location> location, ValueRange operands,
650     DictionaryAttr attributes, RegionRange regions,
651     SmallVectorImpl<Type> &inferredReturnTypes) {
652   if (operands[0].getType() != operands[1].getType()) {
653     return emitOptionalError(location, "operand type mismatch ",
654                              operands[0].getType(), " vs ",
655                              operands[1].getType());
656   }
657   inferredReturnTypes.assign({operands[0].getType()});
658   return success();
659 }
660 
661 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
662     MLIRContext *context, Optional<Location> location, ValueRange operands,
663     DictionaryAttr attributes, RegionRange regions,
664     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
665   // Create return type consisting of the last element of the first operand.
666   auto operandType = *operands.getTypes().begin();
667   auto sval = operandType.dyn_cast<ShapedType>();
668   if (!sval) {
669     return emitOptionalError(location, "only shaped type operands allowed");
670   }
671   int64_t dim =
672       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
673   auto type = IntegerType::get(17, context);
674   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
675   return success();
676 }
677 
678 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
679     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
680   shapes = SmallVector<Value, 1>{
681       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
682   return success();
683 }
684 
685 //===----------------------------------------------------------------------===//
686 // Test SideEffect interfaces
687 //===----------------------------------------------------------------------===//
688 
689 namespace {
690 /// A test resource for side effects.
691 struct TestResource : public SideEffects::Resource::Base<TestResource> {
692   StringRef getName() final { return "<Test>"; }
693 };
694 } // end anonymous namespace
695 
696 void SideEffectOp::getEffects(
697     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
698   // Check for an effects attribute on the op instance.
699   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
700   if (!effectsAttr)
701     return;
702 
703   // If there is one, it is an array of dictionary attributes that hold
704   // information on the effects of this operation.
705   for (Attribute element : effectsAttr) {
706     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
707 
708     // Get the specific memory effect.
709     MemoryEffects::Effect *effect =
710         StringSwitch<MemoryEffects::Effect *>(
711             effectElement.get("effect").cast<StringAttr>().getValue())
712             .Case("allocate", MemoryEffects::Allocate::get())
713             .Case("free", MemoryEffects::Free::get())
714             .Case("read", MemoryEffects::Read::get())
715             .Case("write", MemoryEffects::Write::get());
716 
717     // Check for a result to affect.
718     Value value;
719     if (effectElement.get("on_result"))
720       value = getResult();
721 
722     // Check for a non-default resource to use.
723     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
724     if (effectElement.get("test_resource"))
725       resource = TestResource::get();
726 
727     effects.emplace_back(effect, value, resource);
728   }
729 }
730 
731 //===----------------------------------------------------------------------===//
732 // StringAttrPrettyNameOp
733 //===----------------------------------------------------------------------===//
734 
735 // This op has fancy handling of its SSA result name.
736 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
737                                                OperationState &result) {
738   // Add the result types.
739   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
740     result.addTypes(parser.getBuilder().getIntegerType(32));
741 
742   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
743     return failure();
744 
745   // If the attribute dictionary contains no 'names' attribute, infer it from
746   // the SSA name (if specified).
747   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
748     return attr.first == "names";
749   });
750 
751   // If there was no name specified, check to see if there was a useful name
752   // specified in the asm file.
753   if (hadNames || parser.getNumResults() == 0)
754     return success();
755 
756   SmallVector<StringRef, 4> names;
757   auto *context = result.getContext();
758 
759   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
760     auto resultName = parser.getResultName(i);
761     StringRef nameStr;
762     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
763       nameStr = resultName.first;
764 
765     names.push_back(nameStr);
766   }
767 
768   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
769   result.attributes.push_back({Identifier::get("names", context), namesAttr});
770   return success();
771 }
772 
773 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
774   p << "test.string_attr_pretty_name";
775 
776   // Note that we only need to print the "name" attribute if the asmprinter
777   // result name disagrees with it.  This can happen in strange cases, e.g.
778   // when there are conflicts.
779   bool namesDisagree = op.names().size() != op.getNumResults();
780 
781   SmallString<32> resultNameStr;
782   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
783     resultNameStr.clear();
784     llvm::raw_svector_ostream tmpStream(resultNameStr);
785     p.printOperand(op.getResult(i), tmpStream);
786 
787     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
788     if (!expectedName ||
789         tmpStream.str().drop_front() != expectedName.getValue()) {
790       namesDisagree = true;
791     }
792   }
793 
794   if (namesDisagree)
795     p.printOptionalAttrDictWithKeyword(op.getAttrs());
796   else
797     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
798 }
799 
800 // We set the SSA name in the asm syntax to the contents of the name
801 // attribute.
802 void StringAttrPrettyNameOp::getAsmResultNames(
803     function_ref<void(Value, StringRef)> setNameFn) {
804 
805   auto value = names();
806   for (size_t i = 0, e = value.size(); i != e; ++i)
807     if (auto str = value[i].dyn_cast<StringAttr>())
808       if (!str.getValue().empty())
809         setNameFn(getResult(i), str.getValue());
810 }
811 
812 //===----------------------------------------------------------------------===//
813 // RegionIfOp
814 //===----------------------------------------------------------------------===//
815 
816 static void print(OpAsmPrinter &p, RegionIfOp op) {
817   p << RegionIfOp::getOperationName() << " ";
818   p.printOperands(op.getOperands());
819   p << ": " << op.getOperandTypes();
820   p.printArrowTypeList(op.getResultTypes());
821   p << " then";
822   p.printRegion(op.thenRegion(),
823                 /*printEntryBlockArgs=*/true,
824                 /*printBlockTerminators=*/true);
825   p << " else";
826   p.printRegion(op.elseRegion(),
827                 /*printEntryBlockArgs=*/true,
828                 /*printBlockTerminators=*/true);
829   p << " join";
830   p.printRegion(op.joinRegion(),
831                 /*printEntryBlockArgs=*/true,
832                 /*printBlockTerminators=*/true);
833 }
834 
835 static ParseResult parseRegionIfOp(OpAsmParser &parser,
836                                    OperationState &result) {
837   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
838   SmallVector<Type, 2> operandTypes;
839 
840   result.regions.reserve(3);
841   Region *thenRegion = result.addRegion();
842   Region *elseRegion = result.addRegion();
843   Region *joinRegion = result.addRegion();
844 
845   // Parse operand, type and arrow type lists.
846   if (parser.parseOperandList(operandInfos) ||
847       parser.parseColonTypeList(operandTypes) ||
848       parser.parseArrowTypeList(result.types))
849     return failure();
850 
851   // Parse all attached regions.
852   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
853       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
854       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
855     return failure();
856 
857   return parser.resolveOperands(operandInfos, operandTypes,
858                                 parser.getCurrentLocation(), result.operands);
859 }
860 
861 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
862   assert(index < 2 && "invalid region index");
863   return getOperands();
864 }
865 
866 void RegionIfOp::getSuccessorRegions(
867     Optional<unsigned> index, ArrayRef<Attribute> operands,
868     SmallVectorImpl<RegionSuccessor> &regions) {
869   // We always branch to the join region.
870   if (index.hasValue()) {
871     if (index.getValue() < 2)
872       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
873     else
874       regions.push_back(RegionSuccessor(getResults()));
875     return;
876   }
877 
878   // The then and else regions are the entry regions of this op.
879   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
880   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
881 }
882 
883 #include "TestOpEnums.cpp.inc"
884 #include "TestOpStructs.cpp.inc"
885 #include "TestTypeInterfaces.cpp.inc"
886 
887 #define GET_OP_CLASSES
888 #include "TestOps.cpp.inc"
889