1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/DialectImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/StringSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::test;
23 
/// Inserts `TestDialect` into the given dialect `registry`.
void mlir::test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}
27 
28 //===----------------------------------------------------------------------===//
29 // TestDialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 
34 // Test support for interacting with the AsmPrinter.
35 struct TestOpAsmInterface : public OpAsmDialectInterface {
36   using OpAsmDialectInterface::OpAsmDialectInterface;
37 
38   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
39     StringAttr strAttr = attr.dyn_cast<StringAttr>();
40     if (!strAttr)
41       return failure();
42 
43     // Check the contents of the string attribute to see what the test alias
44     // should be named.
45     Optional<StringRef> aliasName =
46         StringSwitch<Optional<StringRef>>(strAttr.getValue())
47             .Case("alias_test:dot_in_name", StringRef("test.alias"))
48             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
49             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
50             .Case("alias_test:sanitize_conflict_a",
51                   StringRef("test_alias_conflict0"))
52             .Case("alias_test:sanitize_conflict_b",
53                   StringRef("test_alias_conflict0_"))
54             .Default(llvm::None);
55     if (!aliasName)
56       return failure();
57 
58     os << *aliasName;
59     return success();
60   }
61 
62   void getAsmResultNames(Operation *op,
63                          OpAsmSetValueNameFn setNameFn) const final {
64     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
65       setNameFn(asmOp, "result");
66   }
67 
68   void getAsmBlockArgumentNames(Block *block,
69                                 OpAsmSetValueNameFn setNameFn) const final {
70     auto op = block->getParentOp();
71     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
72     if (!arrayAttr)
73       return;
74     auto args = block->getArguments();
75     auto e = std::min(arrayAttr.size(), args.size());
76     for (unsigned i = 0; i < e; ++i) {
77       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
78         setNameFn(args[i], strAttr.getValue());
79     }
80   }
81 };
82 
83 struct TestDialectFoldInterface : public DialectFoldInterface {
84   using DialectFoldInterface::DialectFoldInterface;
85 
86   /// Registered hook to check if the given region, which is attached to an
87   /// operation that is *not* isolated from above, should be used when
88   /// materializing constants.
89   bool shouldMaterializeInto(Region *region) const final {
90     // If this is a one region operation, then insert into it.
91     return isa<OneRegionOp>(region->getParentOp());
92   }
93 };
94 
95 /// This class defines the interface for handling inlining with standard
96 /// operations.
97 struct TestInlinerInterface : public DialectInlinerInterface {
98   using DialectInlinerInterface::DialectInlinerInterface;
99 
100   //===--------------------------------------------------------------------===//
101   // Analysis Hooks
102   //===--------------------------------------------------------------------===//
103 
104   bool isLegalToInline(Operation *call, Operation *callable,
105                        bool wouldBeCloned) const final {
106     // Don't allow inlining calls that are marked `noinline`.
107     return !call->hasAttr("noinline");
108   }
109   bool isLegalToInline(Region *, Region *, bool,
110                        BlockAndValueMapping &) const final {
111     // Inlining into test dialect regions is legal.
112     return true;
113   }
114   bool isLegalToInline(Operation *, Region *, bool,
115                        BlockAndValueMapping &) const final {
116     return true;
117   }
118 
119   bool shouldAnalyzeRecursively(Operation *op) const final {
120     // Analyze recursively if this is not a functional region operation, it
121     // froms a separate functional scope.
122     return !isa<FunctionalRegionOp>(op);
123   }
124 
125   //===--------------------------------------------------------------------===//
126   // Transformation Hooks
127   //===--------------------------------------------------------------------===//
128 
129   /// Handle the given inlined terminator by replacing it with a new operation
130   /// as necessary.
131   void handleTerminator(Operation *op,
132                         ArrayRef<Value> valuesToRepl) const final {
133     // Only handle "test.return" here.
134     auto returnOp = dyn_cast<TestReturnOp>(op);
135     if (!returnOp)
136       return;
137 
138     // Replace the values directly with the return operands.
139     assert(returnOp.getNumOperands() == valuesToRepl.size());
140     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
141       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
142   }
143 
144   /// Attempt to materialize a conversion for a type mismatch between a call
145   /// from this dialect, and a callable region. This method should generate an
146   /// operation that takes 'input' as the only operand, and produces a single
147   /// result of 'resultType'. If a conversion can not be generated, nullptr
148   /// should be returned.
149   Operation *materializeCallConversion(OpBuilder &builder, Value input,
150                                        Type resultType,
151                                        Location conversionLoc) const final {
152     // Only allow conversion for i16/i32 types.
153     if (!(resultType.isSignlessInteger(16) ||
154           resultType.isSignlessInteger(32)) ||
155         !(input.getType().isSignlessInteger(16) ||
156           input.getType().isSignlessInteger(32)))
157       return nullptr;
158     return builder.create<TestCastOp>(conversionLoc, resultType, input);
159   }
160 };
161 } // end anonymous namespace
162 
163 //===----------------------------------------------------------------------===//
164 // TestDialect
165 //===----------------------------------------------------------------------===//
166 
/// Registers the operations, attributes, interfaces and types of the test
/// dialect, and allows unknown operations in it.
void TestDialect::initialize() {
  // The operation list is pulled in from "TestOps.cpp.inc" under GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  // The attribute list is pulled in from "TestAttrDefs.cpp.inc" under
  // GET_ATTRDEF_LIST.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  // TestType and TestRecursiveType are listed by hand; the remaining types
  // are pulled in from "TestTypeDefs.cpp.inc" under GET_TYPEDEF_LIST.
  addTypes<TestType, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  allowUnknownOperations();
}
184 
185 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
186                                             Type type, Location loc) {
187   return builder.create<TestOpConstant>(loc, type, value);
188 }
189 
190 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
191                                                     NamedAttribute namedAttr) {
192   if (namedAttr.first == "test.invalid_attr")
193     return op->emitError() << "invalid to use 'test.invalid_attr'";
194   return success();
195 }
196 
197 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
198                                                     unsigned regionIndex,
199                                                     unsigned argIndex,
200                                                     NamedAttribute namedAttr) {
201   if (namedAttr.first == "test.invalid_attr")
202     return op->emitError() << "invalid to use 'test.invalid_attr'";
203   return success();
204 }
205 
206 LogicalResult
207 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
208                                          unsigned resultIndex,
209                                          NamedAttribute namedAttr) {
210   if (namedAttr.first == "test.invalid_attr")
211     return op->emitError() << "invalid to use 'test.invalid_attr'";
212   return success();
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // TestBranchOp
217 //===----------------------------------------------------------------------===//
218 
219 Optional<MutableOperandRange>
220 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
221   assert(index == 0 && "invalid successor index");
222   return targetOperandsMutable();
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // TestFoldToCallOp
227 //===----------------------------------------------------------------------===//
228 
229 namespace {
230 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
231   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
232 
233   LogicalResult matchAndRewrite(FoldToCallOp op,
234                                 PatternRewriter &rewriter) const override {
235     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
236                                         ValueRange());
237     return success();
238   }
239 };
240 } // end anonymous namespace
241 
/// Adds `FoldToCallOpPattern` to the canonicalization pattern list.
void FoldToCallOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<FoldToCallOpPattern>(context);
}
246 
247 //===----------------------------------------------------------------------===//
248 // Test Format* operations
249 //===----------------------------------------------------------------------===//
250 
251 //===----------------------------------------------------------------------===//
252 // Parsing
253 
254 static ParseResult parseCustomDirectiveOperands(
255     OpAsmParser &parser, OpAsmParser::OperandType &operand,
256     Optional<OpAsmParser::OperandType> &optOperand,
257     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
258   if (parser.parseOperand(operand))
259     return failure();
260   if (succeeded(parser.parseOptionalComma())) {
261     optOperand.emplace();
262     if (parser.parseOperand(*optOperand))
263       return failure();
264   }
265   if (parser.parseArrow() || parser.parseLParen() ||
266       parser.parseOperandList(varOperands) || parser.parseRParen())
267     return failure();
268   return success();
269 }
270 static ParseResult
271 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
272                             Type &optOperandType,
273                             SmallVectorImpl<Type> &varOperandTypes) {
274   if (parser.parseColon())
275     return failure();
276 
277   if (parser.parseType(operandType))
278     return failure();
279   if (succeeded(parser.parseOptionalComma())) {
280     if (parser.parseType(optOperandType))
281       return failure();
282   }
283   if (parser.parseArrow() || parser.parseLParen() ||
284       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
285     return failure();
286   return success();
287 }
288 static ParseResult
289 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
290                                  Type optOperandType,
291                                  const SmallVectorImpl<Type> &varOperandTypes) {
292   if (parser.parseKeyword("type_refs_capture"))
293     return failure();
294 
295   Type operandType2, optOperandType2;
296   SmallVector<Type, 1> varOperandTypes2;
297   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
298                                   varOperandTypes2))
299     return failure();
300 
301   if (operandType != operandType2 || optOperandType != optOperandType2 ||
302       varOperandTypes != varOperandTypes2)
303     return failure();
304 
305   return success();
306 }
307 static ParseResult parseCustomDirectiveOperandsAndTypes(
308     OpAsmParser &parser, OpAsmParser::OperandType &operand,
309     Optional<OpAsmParser::OperandType> &optOperand,
310     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
311     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
312   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
313       parseCustomDirectiveResults(parser, operandType, optOperandType,
314                                   varOperandTypes))
315     return failure();
316   return success();
317 }
318 static ParseResult parseCustomDirectiveRegions(
319     OpAsmParser &parser, Region &region,
320     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
321   if (parser.parseRegion(region))
322     return failure();
323   if (failed(parser.parseOptionalComma()))
324     return success();
325   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
326   if (parser.parseRegion(*varRegion))
327     return failure();
328   varRegions.emplace_back(std::move(varRegion));
329   return success();
330 }
331 static ParseResult
332 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
333                                SmallVectorImpl<Block *> &varSuccessors) {
334   if (parser.parseSuccessor(successor))
335     return failure();
336   if (failed(parser.parseOptionalComma()))
337     return success();
338   Block *varSuccessor;
339   if (parser.parseSuccessor(varSuccessor))
340     return failure();
341   varSuccessors.append(2, varSuccessor);
342   return success();
343 }
344 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
345                                                   IntegerAttr &attr,
346                                                   IntegerAttr &optAttr) {
347   if (parser.parseAttribute(attr))
348     return failure();
349   if (succeeded(parser.parseOptionalComma())) {
350     if (parser.parseAttribute(optAttr))
351       return failure();
352   }
353   return success();
354 }
355 
/// Parses an optional attribute dictionary into `attrs`.
static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
                                                NamedAttrList &attrs) {
  return parser.parseOptionalAttrDict(attrs);
}
360 static ParseResult parseCustomDirectiveOptionalOperandRef(
361     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
362   int64_t operandCount = 0;
363   if (parser.parseInteger(operandCount))
364     return failure();
365   bool expectedOptionalOperand = operandCount == 0;
366   return success(expectedOptionalOperand != optOperand.hasValue());
367 }
368 
369 //===----------------------------------------------------------------------===//
370 // Printing
371 
372 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
373                                          Value operand, Value optOperand,
374                                          OperandRange varOperands) {
375   printer << operand;
376   if (optOperand)
377     printer << ", " << optOperand;
378   printer << " -> (" << varOperands << ")";
379 }
380 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
381                                         Type operandType, Type optOperandType,
382                                         TypeRange varOperandTypes) {
383   printer << " : " << operandType;
384   if (optOperandType)
385     printer << ", " << optOperandType;
386   printer << " -> (" << varOperandTypes << ")";
387 }
/// Prints the ` type_refs_capture ` keyword followed by the same output as
/// the results directive for the given types.
static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
                                             Operation *op, Type operandType,
                                             Type optOperandType,
                                             TypeRange varOperandTypes) {
  printer << " type_refs_capture ";
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
/// Prints the operands directive immediately followed by the results (types)
/// directive.
static void printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
404 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
405                                         Region &region,
406                                         MutableArrayRef<Region> varRegions) {
407   printer.printRegion(region);
408   if (!varRegions.empty()) {
409     printer << ", ";
410     for (Region &region : varRegions)
411       printer.printRegion(region);
412   }
413 }
414 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
415                                            Block *successor,
416                                            SuccessorRange varSuccessors) {
417   printer << successor;
418   if (!varSuccessors.empty())
419     printer << ", " << varSuccessors.front();
420 }
421 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
422                                            Attribute attribute,
423                                            Attribute optAttribute) {
424   printer << attribute;
425   if (optAttribute)
426     printer << ", " << optAttribute;
427 }
428 
429 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
430                                          DictionaryAttr attrs) {
431   printer.printOptionalAttrDict(attrs.getValue());
432 }
433 
434 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
435                                                    Operation *op,
436                                                    Value optOperand) {
437   printer << (optOperand ? "1" : "0");
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // Test IsolatedRegionOp - parse passthrough region arguments.
442 //===----------------------------------------------------------------------===//
443 
444 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
445                                          OperationState &result) {
446   OpAsmParser::OperandType argInfo;
447   Type argType = parser.getBuilder().getIndexType();
448 
449   // Parse the input operand.
450   if (parser.parseOperand(argInfo) ||
451       parser.resolveOperand(argInfo, argType, result.operands))
452     return failure();
453 
454   // Parse the body region, and reuse the operand info as the argument info.
455   Region *body = result.addRegion();
456   return parser.parseRegion(*body, argInfo, argType,
457                             /*enableNameShadowing=*/true);
458 }
459 
460 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
461   p << "test.isolated_region ";
462   p.printOperand(op.getOperand());
463   p.shadowRegionArgs(op.region(), op.getOperand());
464   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
465 }
466 
467 //===----------------------------------------------------------------------===//
468 // Test SSACFGRegionOp
469 //===----------------------------------------------------------------------===//
470 
/// Reports `RegionKind::SSACFG` for every region index.
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
  return RegionKind::SSACFG;
}
474 
475 //===----------------------------------------------------------------------===//
476 // Test GraphRegionOp
477 //===----------------------------------------------------------------------===//
478 
479 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
480                                       OperationState &result) {
481   // Parse the body region, and reuse the operand info as the argument info.
482   Region *body = result.addRegion();
483   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
484 }
485 
486 static void print(OpAsmPrinter &p, GraphRegionOp op) {
487   p << "test.graph_region ";
488   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
489 }
490 
/// Reports `RegionKind::Graph` for every region index.
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
  return RegionKind::Graph;
}
494 
495 //===----------------------------------------------------------------------===//
496 // Test AffineScopeOp
497 //===----------------------------------------------------------------------===//
498 
499 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
500                                       OperationState &result) {
501   // Parse the body region, and reuse the operand info as the argument info.
502   Region *body = result.addRegion();
503   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
504 }
505 
506 static void print(OpAsmPrinter &p, AffineScopeOp op) {
507   p << "test.affine_scope ";
508   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // Test parser.
513 //===----------------------------------------------------------------------===//
514 
515 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
516                                               OperationState &result) {
517   if (parser.parseOptionalColon())
518     return success();
519   uint64_t numResults;
520   if (parser.parseInteger(numResults))
521     return failure();
522 
523   IndexType type = parser.getBuilder().getIndexType();
524   for (unsigned i = 0; i < numResults; ++i)
525     result.addTypes(type);
526   return success();
527 }
528 
529 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
530   p << ParseIntegerLiteralOp::getOperationName();
531   if (unsigned numResults = op->getNumResults())
532     p << " : " << numResults;
533 }
534 
535 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
536                                               OperationState &result) {
537   StringRef keyword;
538   if (parser.parseKeyword(&keyword))
539     return failure();
540   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
541   return success();
542 }
543 
544 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
545   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
546 }
547 
548 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
550 
551 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
552                                          OperationState &result) {
553   if (parser.parseKeyword("wraps"))
554     return failure();
555 
556   // Parse the wrapped op in a region
557   Region &body = *result.addRegion();
558   body.push_back(new Block);
559   Block &block = body.back();
560   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
561   if (!wrapped_op)
562     return failure();
563 
564   // Create a return terminator in the inner region, pass as operand to the
565   // terminator the returned values from the wrapped operation.
566   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
567   OpBuilder builder(parser.getBuilder().getContext());
568   builder.setInsertionPointToEnd(&block);
569   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
570 
571   // Get the results type for the wrapping op from the terminator operands.
572   Operation &return_op = body.back().back();
573   result.types.append(return_op.operand_type_begin(),
574                       return_op.operand_type_end());
575 
576   // Use the location of the wrapped op for the "test.wrapping_region" op.
577   result.location = wrapped_op->getLoc();
578 
579   return success();
580 }
581 
582 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
583   p << op.getOperationName() << " wraps ";
584   p.printGenericOp(&op.region().front().front());
585 }
586 
587 //===----------------------------------------------------------------------===//
588 // Test PolyForOp - parse list of region arguments.
589 //===----------------------------------------------------------------------===//
590 
591 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
592   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
593   // Parse list of region arguments without a delimiter.
594   if (parser.parseRegionArgumentList(ivsInfo))
595     return failure();
596 
597   // Parse the body region.
598   Region *body = result.addRegion();
599   auto &builder = parser.getBuilder();
600   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
601   return parser.parseRegion(*body, ivsInfo, argTypes);
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // Test removing op with inner ops.
606 //===----------------------------------------------------------------------===//
607 
608 namespace {
609 struct TestRemoveOpWithInnerOps
610     : public OpRewritePattern<TestOpWithRegionPattern> {
611   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
612 
613   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
614                                 PatternRewriter &rewriter) const override {
615     rewriter.eraseOp(op);
616     return success();
617   }
618 };
619 } // end anonymous namespace
620 
/// Adds `TestRemoveOpWithInnerOps` to the canonicalization pattern list.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TestRemoveOpWithInnerOps>(context);
}
625 
626 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
627   return operand();
628 }
629 
630 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
631   return getValue();
632 }
633 
634 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
635     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
636   for (Value input : this->operands()) {
637     results.push_back(input);
638   }
639   return success();
640 }
641 
642 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
643   assert(operands.size() == 1);
644   if (operands.front()) {
645     (*this)->setAttr("attr", operands.front());
646     return getResult();
647   }
648   return {};
649 }
650 
651 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
652   return getOperand();
653 }
654 
655 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
656     MLIRContext *, Optional<Location> location, ValueRange operands,
657     DictionaryAttr attributes, RegionRange regions,
658     SmallVectorImpl<Type> &inferredReturnTypes) {
659   if (operands[0].getType() != operands[1].getType()) {
660     return emitOptionalError(location, "operand type mismatch ",
661                              operands[0].getType(), " vs ",
662                              operands[1].getType());
663   }
664   inferredReturnTypes.assign({operands[0].getType()});
665   return success();
666 }
667 
668 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
669     MLIRContext *context, Optional<Location> location, ValueRange operands,
670     DictionaryAttr attributes, RegionRange regions,
671     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
672   // Create return type consisting of the last element of the first operand.
673   auto operandType = *operands.getTypes().begin();
674   auto sval = operandType.dyn_cast<ShapedType>();
675   if (!sval) {
676     return emitOptionalError(location, "only shaped type operands allowed");
677   }
678   int64_t dim =
679       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
680   auto type = IntegerType::get(context, 17);
681   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
682   return success();
683 }
684 
685 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
686     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
687   shapes = SmallVector<Value, 1>{
688       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
689   return success();
690 }
691 
692 //===----------------------------------------------------------------------===//
693 // Test SideEffect interfaces
694 //===----------------------------------------------------------------------===//
695 
696 namespace {
697 /// A test resource for side effects.
698 struct TestResource : public SideEffects::Resource::Base<TestResource> {
699   StringRef getName() final { return "<Test>"; }
700 };
701 } // end anonymous namespace
702 
703 void SideEffectOp::getEffects(
704     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
705   // Check for an effects attribute on the op instance.
706   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
707   if (!effectsAttr)
708     return;
709 
710   // If there is one, it is an array of dictionary attributes that hold
711   // information on the effects of this operation.
712   for (Attribute element : effectsAttr) {
713     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
714 
715     // Get the specific memory effect.
716     MemoryEffects::Effect *effect =
717         StringSwitch<MemoryEffects::Effect *>(
718             effectElement.get("effect").cast<StringAttr>().getValue())
719             .Case("allocate", MemoryEffects::Allocate::get())
720             .Case("free", MemoryEffects::Free::get())
721             .Case("read", MemoryEffects::Read::get())
722             .Case("write", MemoryEffects::Write::get());
723 
724     // Check for a non-default resource to use.
725     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
726     if (effectElement.get("test_resource"))
727       resource = TestResource::get();
728 
729     // Check for a result to affect.
730     if (effectElement.get("on_result"))
731       effects.emplace_back(effect, getResult(), resource);
732     else if (Attribute ref = effectElement.get("on_reference"))
733       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
734     else
735       effects.emplace_back(effect, resource);
736   }
737 }
738 
739 void SideEffectOp::getEffects(
740     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
741   auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
742   if (!effectsAttr)
743     return;
744 
745   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // StringAttrPrettyNameOp
750 //===----------------------------------------------------------------------===//
751 
752 // This op has fancy handling of its SSA result name.
753 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
754                                                OperationState &result) {
755   // Add the result types.
756   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
757     result.addTypes(parser.getBuilder().getIntegerType(32));
758 
759   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
760     return failure();
761 
762   // If the attribute dictionary contains no 'names' attribute, infer it from
763   // the SSA name (if specified).
764   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
765     return attr.first == "names";
766   });
767 
768   // If there was no name specified, check to see if there was a useful name
769   // specified in the asm file.
770   if (hadNames || parser.getNumResults() == 0)
771     return success();
772 
773   SmallVector<StringRef, 4> names;
774   auto *context = result.getContext();
775 
776   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
777     auto resultName = parser.getResultName(i);
778     StringRef nameStr;
779     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
780       nameStr = resultName.first;
781 
782     names.push_back(nameStr);
783   }
784 
785   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
786   result.attributes.push_back({Identifier::get("names", context), namesAttr});
787   return success();
788 }
789 
790 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
791   p << "test.string_attr_pretty_name";
792 
793   // Note that we only need to print the "name" attribute if the asmprinter
794   // result name disagrees with it.  This can happen in strange cases, e.g.
795   // when there are conflicts.
796   bool namesDisagree = op.names().size() != op.getNumResults();
797 
798   SmallString<32> resultNameStr;
799   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
800     resultNameStr.clear();
801     llvm::raw_svector_ostream tmpStream(resultNameStr);
802     p.printOperand(op.getResult(i), tmpStream);
803 
804     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
805     if (!expectedName ||
806         tmpStream.str().drop_front() != expectedName.getValue()) {
807       namesDisagree = true;
808     }
809   }
810 
811   if (namesDisagree)
812     p.printOptionalAttrDictWithKeyword(op->getAttrs());
813   else
814     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
815 }
816 
817 // We set the SSA name in the asm syntax to the contents of the name
818 // attribute.
819 void StringAttrPrettyNameOp::getAsmResultNames(
820     function_ref<void(Value, StringRef)> setNameFn) {
821 
822   auto value = names();
823   for (size_t i = 0, e = value.size(); i != e; ++i)
824     if (auto str = value[i].dyn_cast<StringAttr>())
825       if (!str.getValue().empty())
826         setNameFn(getResult(i), str.getValue());
827 }
828 
829 //===----------------------------------------------------------------------===//
830 // RegionIfOp
831 //===----------------------------------------------------------------------===//
832 
833 static void print(OpAsmPrinter &p, RegionIfOp op) {
834   p << RegionIfOp::getOperationName() << " ";
835   p.printOperands(op.getOperands());
836   p << ": " << op.getOperandTypes();
837   p.printArrowTypeList(op.getResultTypes());
838   p << " then";
839   p.printRegion(op.thenRegion(),
840                 /*printEntryBlockArgs=*/true,
841                 /*printBlockTerminators=*/true);
842   p << " else";
843   p.printRegion(op.elseRegion(),
844                 /*printEntryBlockArgs=*/true,
845                 /*printBlockTerminators=*/true);
846   p << " join";
847   p.printRegion(op.joinRegion(),
848                 /*printEntryBlockArgs=*/true,
849                 /*printBlockTerminators=*/true);
850 }
851 
852 static ParseResult parseRegionIfOp(OpAsmParser &parser,
853                                    OperationState &result) {
854   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
855   SmallVector<Type, 2> operandTypes;
856 
857   result.regions.reserve(3);
858   Region *thenRegion = result.addRegion();
859   Region *elseRegion = result.addRegion();
860   Region *joinRegion = result.addRegion();
861 
862   // Parse operand, type and arrow type lists.
863   if (parser.parseOperandList(operandInfos) ||
864       parser.parseColonTypeList(operandTypes) ||
865       parser.parseArrowTypeList(result.types))
866     return failure();
867 
868   // Parse all attached regions.
869   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
870       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
871       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
872     return failure();
873 
874   return parser.resolveOperands(operandInfos, operandTypes,
875                                 parser.getCurrentLocation(), result.operands);
876 }
877 
878 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
879   assert(index < 2 && "invalid region index");
880   return getOperands();
881 }
882 
883 void RegionIfOp::getSuccessorRegions(
884     Optional<unsigned> index, ArrayRef<Attribute> operands,
885     SmallVectorImpl<RegionSuccessor> &regions) {
886   // We always branch to the join region.
887   if (index.hasValue()) {
888     if (index.getValue() < 2)
889       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
890     else
891       regions.push_back(RegionSuccessor(getResults()));
892     return;
893   }
894 
895   // The then and else regions are the entry regions of this op.
896   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
897   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
898 }
899 
900 #include "TestOpEnums.cpp.inc"
901 #include "TestOpInterfaces.cpp.inc"
902 #include "TestOpStructs.cpp.inc"
903 #include "TestTypeInterfaces.cpp.inc"
904 
905 #define GET_OP_CLASSES
906 #include "TestOps.cpp.inc"
907