1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Function.h"
14 #include "mlir/IR/Module.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringSwitch.h"
21 
22 using namespace mlir;
23 
24 void mlir::registerTestDialect(DialectRegistry &registry) {
25   registry.insert<TestDialect>();
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // TestDialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 
34 // Test support for interacting with the AsmPrinter.
35 struct TestOpAsmInterface : public OpAsmDialectInterface {
36   using OpAsmDialectInterface::OpAsmDialectInterface;
37 
38   void getAsmResultNames(Operation *op,
39                          OpAsmSetValueNameFn setNameFn) const final {
40     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
41       setNameFn(asmOp, "result");
42   }
43 
44   void getAsmBlockArgumentNames(Block *block,
45                                 OpAsmSetValueNameFn setNameFn) const final {
46     auto op = block->getParentOp();
47     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
48     if (!arrayAttr)
49       return;
50     auto args = block->getArguments();
51     auto e = std::min(arrayAttr.size(), args.size());
52     for (unsigned i = 0; i < e; ++i) {
53       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
54         setNameFn(args[i], strAttr.getValue());
55     }
56   }
57 };
58 
59 struct TestDialectFoldInterface : public DialectFoldInterface {
60   using DialectFoldInterface::DialectFoldInterface;
61 
62   /// Registered hook to check if the given region, which is attached to an
63   /// operation that is *not* isolated from above, should be used when
64   /// materializing constants.
65   bool shouldMaterializeInto(Region *region) const final {
66     // If this is a one region operation, then insert into it.
67     return isa<OneRegionOp>(region->getParentOp());
68   }
69 };
70 
71 /// This class defines the interface for handling inlining with standard
72 /// operations.
73 struct TestInlinerInterface : public DialectInlinerInterface {
74   using DialectInlinerInterface::DialectInlinerInterface;
75 
76   //===--------------------------------------------------------------------===//
77   // Analysis Hooks
78   //===--------------------------------------------------------------------===//
79 
80   bool isLegalToInline(Operation *call, Operation *callable,
81                        bool wouldBeCloned) const final {
82     // Don't allow inlining calls that are marked `noinline`.
83     return !call->hasAttr("noinline");
84   }
85   bool isLegalToInline(Region *, Region *, bool,
86                        BlockAndValueMapping &) const final {
87     // Inlining into test dialect regions is legal.
88     return true;
89   }
90   bool isLegalToInline(Operation *, Region *, bool,
91                        BlockAndValueMapping &) const final {
92     return true;
93   }
94 
95   bool shouldAnalyzeRecursively(Operation *op) const final {
96     // Analyze recursively if this is not a functional region operation, it
97     // froms a separate functional scope.
98     return !isa<FunctionalRegionOp>(op);
99   }
100 
101   //===--------------------------------------------------------------------===//
102   // Transformation Hooks
103   //===--------------------------------------------------------------------===//
104 
105   /// Handle the given inlined terminator by replacing it with a new operation
106   /// as necessary.
107   void handleTerminator(Operation *op,
108                         ArrayRef<Value> valuesToRepl) const final {
109     // Only handle "test.return" here.
110     auto returnOp = dyn_cast<TestReturnOp>(op);
111     if (!returnOp)
112       return;
113 
114     // Replace the values directly with the return operands.
115     assert(returnOp.getNumOperands() == valuesToRepl.size());
116     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
117       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
118   }
119 
120   /// Attempt to materialize a conversion for a type mismatch between a call
121   /// from this dialect, and a callable region. This method should generate an
122   /// operation that takes 'input' as the only operand, and produces a single
123   /// result of 'resultType'. If a conversion can not be generated, nullptr
124   /// should be returned.
125   Operation *materializeCallConversion(OpBuilder &builder, Value input,
126                                        Type resultType,
127                                        Location conversionLoc) const final {
128     // Only allow conversion for i16/i32 types.
129     if (!(resultType.isSignlessInteger(16) ||
130           resultType.isSignlessInteger(32)) ||
131         !(input.getType().isSignlessInteger(16) ||
132           input.getType().isSignlessInteger(32)))
133       return nullptr;
134     return builder.create<TestCastOp>(conversionLoc, resultType, input);
135   }
136 };
137 } // end anonymous namespace
138 
139 //===----------------------------------------------------------------------===//
140 // TestDialect
141 //===----------------------------------------------------------------------===//
142 
143 void TestDialect::initialize() {
144   addOperations<
145 #define GET_OP_LIST
146 #include "TestOps.cpp.inc"
147       >();
148   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
149                 TestInlinerInterface>();
150   addTypes<TestType, TestRecursiveType,
151 #define GET_TYPEDEF_LIST
152 #include "TestTypeDefs.cpp.inc"
153            >();
154   allowUnknownOperations();
155 }
156 
157 static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
158                           llvm::SetVector<Type> &stack) {
159   StringRef typeTag;
160   if (failed(parser.parseKeyword(&typeTag)))
161     return Type();
162 
163   auto genType = generatedTypeParser(ctxt, parser, typeTag);
164   if (genType != Type())
165     return genType;
166 
167   if (typeTag == "test_type")
168     return TestType::get(parser.getBuilder().getContext());
169 
170   if (typeTag != "test_rec")
171     return Type();
172 
173   StringRef name;
174   if (parser.parseLess() || parser.parseKeyword(&name))
175     return Type();
176   auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
177 
178   // If this type already has been parsed above in the stack, expect just the
179   // name.
180   if (stack.contains(rec)) {
181     if (failed(parser.parseGreater()))
182       return Type();
183     return rec;
184   }
185 
186   // Otherwise, parse the body and update the type.
187   if (failed(parser.parseComma()))
188     return Type();
189   stack.insert(rec);
190   Type subtype = parseTestType(ctxt, parser, stack);
191   stack.pop_back();
192   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
193     return Type();
194 
195   return rec;
196 }
197 
198 Type TestDialect::parseType(DialectAsmParser &parser) const {
199   llvm::SetVector<Type> stack;
200   return parseTestType(getContext(), parser, stack);
201 }
202 
203 static void printTestType(Type type, DialectAsmPrinter &printer,
204                           llvm::SetVector<Type> &stack) {
205   if (succeeded(generatedTypePrinter(type, printer)))
206     return;
207   if (type.isa<TestType>()) {
208     printer << "test_type";
209     return;
210   }
211 
212   auto rec = type.cast<TestRecursiveType>();
213   printer << "test_rec<" << rec.getName();
214   if (!stack.contains(rec)) {
215     printer << ", ";
216     stack.insert(rec);
217     printTestType(rec.getBody(), printer, stack);
218     stack.pop_back();
219   }
220   printer << ">";
221 }
222 
223 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
224   llvm::SetVector<Type> stack;
225   printTestType(type, printer, stack);
226 }
227 
228 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
229                                                     NamedAttribute namedAttr) {
230   if (namedAttr.first == "test.invalid_attr")
231     return op->emitError() << "invalid to use 'test.invalid_attr'";
232   return success();
233 }
234 
235 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
236                                                     unsigned regionIndex,
237                                                     unsigned argIndex,
238                                                     NamedAttribute namedAttr) {
239   if (namedAttr.first == "test.invalid_attr")
240     return op->emitError() << "invalid to use 'test.invalid_attr'";
241   return success();
242 }
243 
244 LogicalResult
245 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
246                                          unsigned resultIndex,
247                                          NamedAttribute namedAttr) {
248   if (namedAttr.first == "test.invalid_attr")
249     return op->emitError() << "invalid to use 'test.invalid_attr'";
250   return success();
251 }
252 
253 //===----------------------------------------------------------------------===//
254 // TestBranchOp
255 //===----------------------------------------------------------------------===//
256 
257 Optional<MutableOperandRange>
258 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
259   assert(index == 0 && "invalid successor index");
260   return targetOperandsMutable();
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // TestFoldToCallOp
265 //===----------------------------------------------------------------------===//
266 
267 namespace {
268 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
269   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
270 
271   LogicalResult matchAndRewrite(FoldToCallOp op,
272                                 PatternRewriter &rewriter) const override {
273     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
274                                         ValueRange());
275     return success();
276   }
277 };
278 } // end anonymous namespace
279 
280 void FoldToCallOp::getCanonicalizationPatterns(
281     OwningRewritePatternList &results, MLIRContext *context) {
282   results.insert<FoldToCallOpPattern>(context);
283 }
284 
285 //===----------------------------------------------------------------------===//
286 // Test Format* operations
287 //===----------------------------------------------------------------------===//
288 
289 //===----------------------------------------------------------------------===//
290 // Parsing
291 
292 static ParseResult parseCustomDirectiveOperands(
293     OpAsmParser &parser, OpAsmParser::OperandType &operand,
294     Optional<OpAsmParser::OperandType> &optOperand,
295     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
296   if (parser.parseOperand(operand))
297     return failure();
298   if (succeeded(parser.parseOptionalComma())) {
299     optOperand.emplace();
300     if (parser.parseOperand(*optOperand))
301       return failure();
302   }
303   if (parser.parseArrow() || parser.parseLParen() ||
304       parser.parseOperandList(varOperands) || parser.parseRParen())
305     return failure();
306   return success();
307 }
308 static ParseResult
309 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
310                             Type &optOperandType,
311                             SmallVectorImpl<Type> &varOperandTypes) {
312   if (parser.parseColon())
313     return failure();
314 
315   if (parser.parseType(operandType))
316     return failure();
317   if (succeeded(parser.parseOptionalComma())) {
318     if (parser.parseType(optOperandType))
319       return failure();
320   }
321   if (parser.parseArrow() || parser.parseLParen() ||
322       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
323     return failure();
324   return success();
325 }
326 static ParseResult
327 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
328                                  Type optOperandType,
329                                  const SmallVectorImpl<Type> &varOperandTypes) {
330   if (parser.parseKeyword("type_refs_capture"))
331     return failure();
332 
333   Type operandType2, optOperandType2;
334   SmallVector<Type, 1> varOperandTypes2;
335   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
336                                   varOperandTypes2))
337     return failure();
338 
339   if (operandType != operandType2 || optOperandType != optOperandType2 ||
340       varOperandTypes != varOperandTypes2)
341     return failure();
342 
343   return success();
344 }
345 static ParseResult parseCustomDirectiveOperandsAndTypes(
346     OpAsmParser &parser, OpAsmParser::OperandType &operand,
347     Optional<OpAsmParser::OperandType> &optOperand,
348     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
349     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
350   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
351       parseCustomDirectiveResults(parser, operandType, optOperandType,
352                                   varOperandTypes))
353     return failure();
354   return success();
355 }
356 static ParseResult parseCustomDirectiveRegions(
357     OpAsmParser &parser, Region &region,
358     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
359   if (parser.parseRegion(region))
360     return failure();
361   if (failed(parser.parseOptionalComma()))
362     return success();
363   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
364   if (parser.parseRegion(*varRegion))
365     return failure();
366   varRegions.emplace_back(std::move(varRegion));
367   return success();
368 }
369 static ParseResult
370 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
371                                SmallVectorImpl<Block *> &varSuccessors) {
372   if (parser.parseSuccessor(successor))
373     return failure();
374   if (failed(parser.parseOptionalComma()))
375     return success();
376   Block *varSuccessor;
377   if (parser.parseSuccessor(varSuccessor))
378     return failure();
379   varSuccessors.append(2, varSuccessor);
380   return success();
381 }
382 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
383                                                   IntegerAttr &attr,
384                                                   IntegerAttr &optAttr) {
385   if (parser.parseAttribute(attr))
386     return failure();
387   if (succeeded(parser.parseOptionalComma())) {
388     if (parser.parseAttribute(optAttr))
389       return failure();
390   }
391   return success();
392 }
393 
394 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
395                                                 NamedAttrList &attrs) {
396   return parser.parseOptionalAttrDict(attrs);
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // Printing
401 
402 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
403                                          Value operand, Value optOperand,
404                                          OperandRange varOperands) {
405   printer << operand;
406   if (optOperand)
407     printer << ", " << optOperand;
408   printer << " -> (" << varOperands << ")";
409 }
410 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
411                                         Type operandType, Type optOperandType,
412                                         TypeRange varOperandTypes) {
413   printer << " : " << operandType;
414   if (optOperandType)
415     printer << ", " << optOperandType;
416   printer << " -> (" << varOperandTypes << ")";
417 }
418 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
419                                              Operation *op, Type operandType,
420                                              Type optOperandType,
421                                              TypeRange varOperandTypes) {
422   printer << " type_refs_capture ";
423   printCustomDirectiveResults(printer, op, operandType, optOperandType,
424                               varOperandTypes);
425 }
426 static void printCustomDirectiveOperandsAndTypes(
427     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
428     OperandRange varOperands, Type operandType, Type optOperandType,
429     TypeRange varOperandTypes) {
430   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
431   printCustomDirectiveResults(printer, op, operandType, optOperandType,
432                               varOperandTypes);
433 }
434 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
435                                         Region &region,
436                                         MutableArrayRef<Region> varRegions) {
437   printer.printRegion(region);
438   if (!varRegions.empty()) {
439     printer << ", ";
440     for (Region &region : varRegions)
441       printer.printRegion(region);
442   }
443 }
444 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
445                                            Block *successor,
446                                            SuccessorRange varSuccessors) {
447   printer << successor;
448   if (!varSuccessors.empty())
449     printer << ", " << varSuccessors.front();
450 }
451 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
452                                            Attribute attribute,
453                                            Attribute optAttribute) {
454   printer << attribute;
455   if (optAttribute)
456     printer << ", " << optAttribute;
457 }
458 
459 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
460                                          MutableDictionaryAttr attrs) {
461   printer.printOptionalAttrDict(attrs.getAttrs());
462 }
463 //===----------------------------------------------------------------------===//
464 // Test IsolatedRegionOp - parse passthrough region arguments.
465 //===----------------------------------------------------------------------===//
466 
467 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
468                                          OperationState &result) {
469   OpAsmParser::OperandType argInfo;
470   Type argType = parser.getBuilder().getIndexType();
471 
472   // Parse the input operand.
473   if (parser.parseOperand(argInfo) ||
474       parser.resolveOperand(argInfo, argType, result.operands))
475     return failure();
476 
477   // Parse the body region, and reuse the operand info as the argument info.
478   Region *body = result.addRegion();
479   return parser.parseRegion(*body, argInfo, argType,
480                             /*enableNameShadowing=*/true);
481 }
482 
483 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
484   p << "test.isolated_region ";
485   p.printOperand(op.getOperand());
486   p.shadowRegionArgs(op.region(), op.getOperand());
487   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
488 }
489 
490 //===----------------------------------------------------------------------===//
491 // Test SSACFGRegionOp
492 //===----------------------------------------------------------------------===//
493 
494 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
495   return RegionKind::SSACFG;
496 }
497 
498 //===----------------------------------------------------------------------===//
499 // Test GraphRegionOp
500 //===----------------------------------------------------------------------===//
501 
502 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
503                                       OperationState &result) {
504   // Parse the body region, and reuse the operand info as the argument info.
505   Region *body = result.addRegion();
506   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
507 }
508 
509 static void print(OpAsmPrinter &p, GraphRegionOp op) {
510   p << "test.graph_region ";
511   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
512 }
513 
514 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
515   return RegionKind::Graph;
516 }
517 
518 //===----------------------------------------------------------------------===//
519 // Test AffineScopeOp
520 //===----------------------------------------------------------------------===//
521 
522 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
523                                       OperationState &result) {
524   // Parse the body region, and reuse the operand info as the argument info.
525   Region *body = result.addRegion();
526   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
527 }
528 
529 static void print(OpAsmPrinter &p, AffineScopeOp op) {
530   p << "test.affine_scope ";
531   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // Test parser.
536 //===----------------------------------------------------------------------===//
537 
538 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
539                                          OperationState &result) {
540   StringRef keyword;
541   if (parser.parseKeyword(&keyword))
542     return failure();
543   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
544   return success();
545 }
546 
547 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
548   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
549 }
550 
551 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
553 
554 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
555                                          OperationState &result) {
556   if (parser.parseKeyword("wraps"))
557     return failure();
558 
559   // Parse the wrapped op in a region
560   Region &body = *result.addRegion();
561   body.push_back(new Block);
562   Block &block = body.back();
563   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
564   if (!wrapped_op)
565     return failure();
566 
567   // Create a return terminator in the inner region, pass as operand to the
568   // terminator the returned values from the wrapped operation.
569   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
570   OpBuilder builder(parser.getBuilder().getContext());
571   builder.setInsertionPointToEnd(&block);
572   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
573 
574   // Get the results type for the wrapping op from the terminator operands.
575   Operation &return_op = body.back().back();
576   result.types.append(return_op.operand_type_begin(),
577                       return_op.operand_type_end());
578 
579   // Use the location of the wrapped op for the "test.wrapping_region" op.
580   result.location = wrapped_op->getLoc();
581 
582   return success();
583 }
584 
585 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
586   p << op.getOperationName() << " wraps ";
587   p.printGenericOp(&op.region().front().front());
588 }
589 
590 //===----------------------------------------------------------------------===//
591 // Test PolyForOp - parse list of region arguments.
592 //===----------------------------------------------------------------------===//
593 
594 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
595   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
596   // Parse list of region arguments without a delimiter.
597   if (parser.parseRegionArgumentList(ivsInfo))
598     return failure();
599 
600   // Parse the body region.
601   Region *body = result.addRegion();
602   auto &builder = parser.getBuilder();
603   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
604   return parser.parseRegion(*body, ivsInfo, argTypes);
605 }
606 
607 //===----------------------------------------------------------------------===//
608 // Test removing op with inner ops.
609 //===----------------------------------------------------------------------===//
610 
611 namespace {
612 struct TestRemoveOpWithInnerOps
613     : public OpRewritePattern<TestOpWithRegionPattern> {
614   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
615 
616   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
617                                 PatternRewriter &rewriter) const override {
618     rewriter.eraseOp(op);
619     return success();
620   }
621 };
622 } // end anonymous namespace
623 
624 void TestOpWithRegionPattern::getCanonicalizationPatterns(
625     OwningRewritePatternList &results, MLIRContext *context) {
626   results.insert<TestRemoveOpWithInnerOps>(context);
627 }
628 
629 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
630   return operand();
631 }
632 
633 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
634   return getValue();
635 }
636 
637 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
638     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
639   for (Value input : this->operands()) {
640     results.push_back(input);
641   }
642   return success();
643 }
644 
645 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
646   assert(operands.size() == 1);
647   if (operands.front()) {
648     setAttr("attr", operands.front());
649     return getResult();
650   }
651   return {};
652 }
653 
654 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
655     MLIRContext *, Optional<Location> location, ValueRange operands,
656     DictionaryAttr attributes, RegionRange regions,
657     SmallVectorImpl<Type> &inferredReturnTypes) {
658   if (operands[0].getType() != operands[1].getType()) {
659     return emitOptionalError(location, "operand type mismatch ",
660                              operands[0].getType(), " vs ",
661                              operands[1].getType());
662   }
663   inferredReturnTypes.assign({operands[0].getType()});
664   return success();
665 }
666 
667 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
668     MLIRContext *context, Optional<Location> location, ValueRange operands,
669     DictionaryAttr attributes, RegionRange regions,
670     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
671   // Create return type consisting of the last element of the first operand.
672   auto operandType = *operands.getTypes().begin();
673   auto sval = operandType.dyn_cast<ShapedType>();
674   if (!sval) {
675     return emitOptionalError(location, "only shaped type operands allowed");
676   }
677   int64_t dim =
678       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
679   auto type = IntegerType::get(17, context);
680   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
681   return success();
682 }
683 
684 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
685     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
686   shapes = SmallVector<Value, 1>{
687       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
688   return success();
689 }
690 
691 //===----------------------------------------------------------------------===//
692 // Test SideEffect interfaces
693 //===----------------------------------------------------------------------===//
694 
695 namespace {
696 /// A test resource for side effects.
697 struct TestResource : public SideEffects::Resource::Base<TestResource> {
698   StringRef getName() final { return "<Test>"; }
699 };
700 } // end anonymous namespace
701 
702 void SideEffectOp::getEffects(
703     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
704   // Check for an effects attribute on the op instance.
705   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
706   if (!effectsAttr)
707     return;
708 
709   // If there is one, it is an array of dictionary attributes that hold
710   // information on the effects of this operation.
711   for (Attribute element : effectsAttr) {
712     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
713 
714     // Get the specific memory effect.
715     MemoryEffects::Effect *effect =
716         StringSwitch<MemoryEffects::Effect *>(
717             effectElement.get("effect").cast<StringAttr>().getValue())
718             .Case("allocate", MemoryEffects::Allocate::get())
719             .Case("free", MemoryEffects::Free::get())
720             .Case("read", MemoryEffects::Read::get())
721             .Case("write", MemoryEffects::Write::get());
722 
723     // Check for a result to affect.
724     Value value;
725     if (effectElement.get("on_result"))
726       value = getResult();
727 
728     // Check for a non-default resource to use.
729     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
730     if (effectElement.get("test_resource"))
731       resource = TestResource::get();
732 
733     effects.emplace_back(effect, value, resource);
734   }
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // StringAttrPrettyNameOp
739 //===----------------------------------------------------------------------===//
740 
741 // This op has fancy handling of its SSA result name.
742 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
743                                                OperationState &result) {
744   // Add the result types.
745   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
746     result.addTypes(parser.getBuilder().getIntegerType(32));
747 
748   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
749     return failure();
750 
751   // If the attribute dictionary contains no 'names' attribute, infer it from
752   // the SSA name (if specified).
753   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
754     return attr.first == "names";
755   });
756 
757   // If there was no name specified, check to see if there was a useful name
758   // specified in the asm file.
759   if (hadNames || parser.getNumResults() == 0)
760     return success();
761 
762   SmallVector<StringRef, 4> names;
763   auto *context = result.getContext();
764 
765   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
766     auto resultName = parser.getResultName(i);
767     StringRef nameStr;
768     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
769       nameStr = resultName.first;
770 
771     names.push_back(nameStr);
772   }
773 
774   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
775   result.attributes.push_back({Identifier::get("names", context), namesAttr});
776   return success();
777 }
778 
779 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
780   p << "test.string_attr_pretty_name";
781 
782   // Note that we only need to print the "name" attribute if the asmprinter
783   // result name disagrees with it.  This can happen in strange cases, e.g.
784   // when there are conflicts.
785   bool namesDisagree = op.names().size() != op.getNumResults();
786 
787   SmallString<32> resultNameStr;
788   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
789     resultNameStr.clear();
790     llvm::raw_svector_ostream tmpStream(resultNameStr);
791     p.printOperand(op.getResult(i), tmpStream);
792 
793     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
794     if (!expectedName ||
795         tmpStream.str().drop_front() != expectedName.getValue()) {
796       namesDisagree = true;
797     }
798   }
799 
800   if (namesDisagree)
801     p.printOptionalAttrDictWithKeyword(op.getAttrs());
802   else
803     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
804 }
805 
806 // We set the SSA name in the asm syntax to the contents of the name
807 // attribute.
808 void StringAttrPrettyNameOp::getAsmResultNames(
809     function_ref<void(Value, StringRef)> setNameFn) {
810 
811   auto value = names();
812   for (size_t i = 0, e = value.size(); i != e; ++i)
813     if (auto str = value[i].dyn_cast<StringAttr>())
814       if (!str.getValue().empty())
815         setNameFn(getResult(i), str.getValue());
816 }
817 
818 //===----------------------------------------------------------------------===//
819 // RegionIfOp
820 //===----------------------------------------------------------------------===//
821 
822 static void print(OpAsmPrinter &p, RegionIfOp op) {
823   p << RegionIfOp::getOperationName() << " ";
824   p.printOperands(op.getOperands());
825   p << ": " << op.getOperandTypes();
826   p.printArrowTypeList(op.getResultTypes());
827   p << " then";
828   p.printRegion(op.thenRegion(),
829                 /*printEntryBlockArgs=*/true,
830                 /*printBlockTerminators=*/true);
831   p << " else";
832   p.printRegion(op.elseRegion(),
833                 /*printEntryBlockArgs=*/true,
834                 /*printBlockTerminators=*/true);
835   p << " join";
836   p.printRegion(op.joinRegion(),
837                 /*printEntryBlockArgs=*/true,
838                 /*printBlockTerminators=*/true);
839 }
840 
841 static ParseResult parseRegionIfOp(OpAsmParser &parser,
842                                    OperationState &result) {
843   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
844   SmallVector<Type, 2> operandTypes;
845 
846   result.regions.reserve(3);
847   Region *thenRegion = result.addRegion();
848   Region *elseRegion = result.addRegion();
849   Region *joinRegion = result.addRegion();
850 
851   // Parse operand, type and arrow type lists.
852   if (parser.parseOperandList(operandInfos) ||
853       parser.parseColonTypeList(operandTypes) ||
854       parser.parseArrowTypeList(result.types))
855     return failure();
856 
857   // Parse all attached regions.
858   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
859       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
860       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
861     return failure();
862 
863   return parser.resolveOperands(operandInfos, operandTypes,
864                                 parser.getCurrentLocation(), result.operands);
865 }
866 
867 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
868   assert(index < 2 && "invalid region index");
869   return getOperands();
870 }
871 
872 void RegionIfOp::getSuccessorRegions(
873     Optional<unsigned> index, ArrayRef<Attribute> operands,
874     SmallVectorImpl<RegionSuccessor> &regions) {
875   // We always branch to the join region.
876   if (index.hasValue()) {
877     if (index.getValue() < 2)
878       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
879     else
880       regions.push_back(RegionSuccessor(getResults()));
881     return;
882   }
883 
884   // The then and else regions are the entry regions of this op.
885   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
886   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
887 }
888 
889 #include "TestOpEnums.cpp.inc"
890 #include "TestOpStructs.cpp.inc"
891 #include "TestTypeInterfaces.cpp.inc"
892 
893 #define GET_OP_CLASSES
894 #include "TestOps.cpp.inc"
895