1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/DLTI/DLTI.h"
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/DialectImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "mlir/Transforms/FoldUtils.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/StringSwitch.h"
21 
22 using namespace mlir;
23 using namespace mlir::test;
24 
/// Inserts the test dialect (TestDialect) into `registry`.
void mlir::test::registerTestDialect(DialectRegistry &registry) {
  registry.insert<TestDialect>();
}
28 
29 //===----------------------------------------------------------------------===//
30 // TestDialect Interfaces
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 
35 // Test support for interacting with the AsmPrinter.
36 struct TestOpAsmInterface : public OpAsmDialectInterface {
37   using OpAsmDialectInterface::OpAsmDialectInterface;
38 
39   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
40     StringAttr strAttr = attr.dyn_cast<StringAttr>();
41     if (!strAttr)
42       return failure();
43 
44     // Check the contents of the string attribute to see what the test alias
45     // should be named.
46     Optional<StringRef> aliasName =
47         StringSwitch<Optional<StringRef>>(strAttr.getValue())
48             .Case("alias_test:dot_in_name", StringRef("test.alias"))
49             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
50             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
51             .Case("alias_test:sanitize_conflict_a",
52                   StringRef("test_alias_conflict0"))
53             .Case("alias_test:sanitize_conflict_b",
54                   StringRef("test_alias_conflict0_"))
55             .Default(llvm::None);
56     if (!aliasName)
57       return failure();
58 
59     os << *aliasName;
60     return success();
61   }
62 
63   void getAsmResultNames(Operation *op,
64                          OpAsmSetValueNameFn setNameFn) const final {
65     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
66       setNameFn(asmOp, "result");
67   }
68 
69   void getAsmBlockArgumentNames(Block *block,
70                                 OpAsmSetValueNameFn setNameFn) const final {
71     auto op = block->getParentOp();
72     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
73     if (!arrayAttr)
74       return;
75     auto args = block->getArguments();
76     auto e = std::min(arrayAttr.size(), args.size());
77     for (unsigned i = 0; i < e; ++i) {
78       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
79         setNameFn(args[i], strAttr.getValue());
80     }
81   }
82 };
83 
84 struct TestDialectFoldInterface : public DialectFoldInterface {
85   using DialectFoldInterface::DialectFoldInterface;
86 
87   /// Registered hook to check if the given region, which is attached to an
88   /// operation that is *not* isolated from above, should be used when
89   /// materializing constants.
90   bool shouldMaterializeInto(Region *region) const final {
91     // If this is a one region operation, then insert into it.
92     return isa<OneRegionOp>(region->getParentOp());
93   }
94 };
95 
96 /// This class defines the interface for handling inlining with standard
97 /// operations.
98 struct TestInlinerInterface : public DialectInlinerInterface {
99   using DialectInlinerInterface::DialectInlinerInterface;
100 
101   //===--------------------------------------------------------------------===//
102   // Analysis Hooks
103   //===--------------------------------------------------------------------===//
104 
105   bool isLegalToInline(Operation *call, Operation *callable,
106                        bool wouldBeCloned) const final {
107     // Don't allow inlining calls that are marked `noinline`.
108     return !call->hasAttr("noinline");
109   }
110   bool isLegalToInline(Region *, Region *, bool,
111                        BlockAndValueMapping &) const final {
112     // Inlining into test dialect regions is legal.
113     return true;
114   }
115   bool isLegalToInline(Operation *, Region *, bool,
116                        BlockAndValueMapping &) const final {
117     return true;
118   }
119 
120   bool shouldAnalyzeRecursively(Operation *op) const final {
121     // Analyze recursively if this is not a functional region operation, it
122     // froms a separate functional scope.
123     return !isa<FunctionalRegionOp>(op);
124   }
125 
126   //===--------------------------------------------------------------------===//
127   // Transformation Hooks
128   //===--------------------------------------------------------------------===//
129 
130   /// Handle the given inlined terminator by replacing it with a new operation
131   /// as necessary.
132   void handleTerminator(Operation *op,
133                         ArrayRef<Value> valuesToRepl) const final {
134     // Only handle "test.return" here.
135     auto returnOp = dyn_cast<TestReturnOp>(op);
136     if (!returnOp)
137       return;
138 
139     // Replace the values directly with the return operands.
140     assert(returnOp.getNumOperands() == valuesToRepl.size());
141     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
142       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
143   }
144 
145   /// Attempt to materialize a conversion for a type mismatch between a call
146   /// from this dialect, and a callable region. This method should generate an
147   /// operation that takes 'input' as the only operand, and produces a single
148   /// result of 'resultType'. If a conversion can not be generated, nullptr
149   /// should be returned.
150   Operation *materializeCallConversion(OpBuilder &builder, Value input,
151                                        Type resultType,
152                                        Location conversionLoc) const final {
153     // Only allow conversion for i16/i32 types.
154     if (!(resultType.isSignlessInteger(16) ||
155           resultType.isSignlessInteger(32)) ||
156         !(input.getType().isSignlessInteger(16) ||
157           input.getType().isSignlessInteger(32)))
158       return nullptr;
159     return builder.create<TestCastOp>(conversionLoc, resultType, input);
160   }
161 };
162 } // end anonymous namespace
163 
164 //===----------------------------------------------------------------------===//
165 // TestDialect
166 //===----------------------------------------------------------------------===//
167 
/// Registers the test dialect's operations, attributes, interfaces and types.
/// The op, attribute and most type lists are expanded from the included
/// `.inc` files, selected by the GET_*_LIST macros.
void TestDialect::initialize() {
  // Operations listed in TestOps.cpp.inc.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  // Attributes listed in TestAttrDefs.cpp.inc.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "TestAttrDefs.cpp.inc"
      >();
  // AsmPrinter, folding and inliner hooks defined earlier in this file.
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  // Three explicitly named types, followed by those listed in
  // TestTypeDefs.cpp.inc.
  addTypes<TestType, TestTypeWithLayout, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  // Accept operations that this dialect does not define.
  allowUnknownOperations();
}
185 
186 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
187                                             Type type, Location loc) {
188   return builder.create<TestOpConstant>(loc, type, value);
189 }
190 
191 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
192                                                     NamedAttribute namedAttr) {
193   if (namedAttr.first == "test.invalid_attr")
194     return op->emitError() << "invalid to use 'test.invalid_attr'";
195   return success();
196 }
197 
198 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
199                                                     unsigned regionIndex,
200                                                     unsigned argIndex,
201                                                     NamedAttribute namedAttr) {
202   if (namedAttr.first == "test.invalid_attr")
203     return op->emitError() << "invalid to use 'test.invalid_attr'";
204   return success();
205 }
206 
207 LogicalResult
208 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
209                                          unsigned resultIndex,
210                                          NamedAttribute namedAttr) {
211   if (namedAttr.first == "test.invalid_attr")
212     return op->emitError() << "invalid to use 'test.invalid_attr'";
213   return success();
214 }
215 
216 //===----------------------------------------------------------------------===//
217 // TestBranchOp
218 //===----------------------------------------------------------------------===//
219 
220 Optional<MutableOperandRange>
221 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
222   assert(index == 0 && "invalid successor index");
223   return targetOperandsMutable();
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // TestFoldToCallOp
228 //===----------------------------------------------------------------------===//
229 
230 namespace {
231 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
232   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
233 
234   LogicalResult matchAndRewrite(FoldToCallOp op,
235                                 PatternRewriter &rewriter) const override {
236     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
237                                         ValueRange());
238     return success();
239   }
240 };
241 } // end anonymous namespace
242 
/// Registers the only canonicalization for FoldToCallOp, which rewrites it
/// into a CallOp (FoldToCallOpPattern).
void FoldToCallOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<FoldToCallOpPattern>(context);
}
247 
248 //===----------------------------------------------------------------------===//
249 // Test Format* operations
250 //===----------------------------------------------------------------------===//
251 
252 //===----------------------------------------------------------------------===//
253 // Parsing
254 
255 static ParseResult parseCustomDirectiveOperands(
256     OpAsmParser &parser, OpAsmParser::OperandType &operand,
257     Optional<OpAsmParser::OperandType> &optOperand,
258     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
259   if (parser.parseOperand(operand))
260     return failure();
261   if (succeeded(parser.parseOptionalComma())) {
262     optOperand.emplace();
263     if (parser.parseOperand(*optOperand))
264       return failure();
265   }
266   if (parser.parseArrow() || parser.parseLParen() ||
267       parser.parseOperandList(varOperands) || parser.parseRParen())
268     return failure();
269   return success();
270 }
271 static ParseResult
272 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
273                             Type &optOperandType,
274                             SmallVectorImpl<Type> &varOperandTypes) {
275   if (parser.parseColon())
276     return failure();
277 
278   if (parser.parseType(operandType))
279     return failure();
280   if (succeeded(parser.parseOptionalComma())) {
281     if (parser.parseType(optOperandType))
282       return failure();
283   }
284   if (parser.parseArrow() || parser.parseLParen() ||
285       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
286     return failure();
287   return success();
288 }
289 static ParseResult
290 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
291                                  Type optOperandType,
292                                  const SmallVectorImpl<Type> &varOperandTypes) {
293   if (parser.parseKeyword("type_refs_capture"))
294     return failure();
295 
296   Type operandType2, optOperandType2;
297   SmallVector<Type, 1> varOperandTypes2;
298   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
299                                   varOperandTypes2))
300     return failure();
301 
302   if (operandType != operandType2 || optOperandType != optOperandType2 ||
303       varOperandTypes != varOperandTypes2)
304     return failure();
305 
306   return success();
307 }
308 static ParseResult parseCustomDirectiveOperandsAndTypes(
309     OpAsmParser &parser, OpAsmParser::OperandType &operand,
310     Optional<OpAsmParser::OperandType> &optOperand,
311     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
312     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
313   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
314       parseCustomDirectiveResults(parser, operandType, optOperandType,
315                                   varOperandTypes))
316     return failure();
317   return success();
318 }
319 static ParseResult parseCustomDirectiveRegions(
320     OpAsmParser &parser, Region &region,
321     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
322   if (parser.parseRegion(region))
323     return failure();
324   if (failed(parser.parseOptionalComma()))
325     return success();
326   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
327   if (parser.parseRegion(*varRegion))
328     return failure();
329   varRegions.emplace_back(std::move(varRegion));
330   return success();
331 }
332 static ParseResult
333 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
334                                SmallVectorImpl<Block *> &varSuccessors) {
335   if (parser.parseSuccessor(successor))
336     return failure();
337   if (failed(parser.parseOptionalComma()))
338     return success();
339   Block *varSuccessor;
340   if (parser.parseSuccessor(varSuccessor))
341     return failure();
342   varSuccessors.append(2, varSuccessor);
343   return success();
344 }
345 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
346                                                   IntegerAttr &attr,
347                                                   IntegerAttr &optAttr) {
348   if (parser.parseAttribute(attr))
349     return failure();
350   if (succeeded(parser.parseOptionalComma())) {
351     if (parser.parseAttribute(optAttr))
352       return failure();
353   }
354   return success();
355 }
356 
357 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
358                                                 NamedAttrList &attrs) {
359   return parser.parseOptionalAttrDict(attrs);
360 }
361 static ParseResult parseCustomDirectiveOptionalOperandRef(
362     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
363   int64_t operandCount = 0;
364   if (parser.parseInteger(operandCount))
365     return failure();
366   bool expectedOptionalOperand = operandCount == 0;
367   return success(expectedOptionalOperand != optOperand.hasValue());
368 }
369 
370 //===----------------------------------------------------------------------===//
371 // Printing
372 
373 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
374                                          Value operand, Value optOperand,
375                                          OperandRange varOperands) {
376   printer << operand;
377   if (optOperand)
378     printer << ", " << optOperand;
379   printer << " -> (" << varOperands << ")";
380 }
381 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
382                                         Type operandType, Type optOperandType,
383                                         TypeRange varOperandTypes) {
384   printer << " : " << operandType;
385   if (optOperandType)
386     printer << ", " << optOperandType;
387   printer << " -> (" << varOperandTypes << ")";
388 }
389 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
390                                              Operation *op, Type operandType,
391                                              Type optOperandType,
392                                              TypeRange varOperandTypes) {
393   printer << " type_refs_capture ";
394   printCustomDirectiveResults(printer, op, operandType, optOperandType,
395                               varOperandTypes);
396 }
/// Prints the operand list (printCustomDirectiveOperands) immediately followed
/// by the type list (printCustomDirectiveResults). This is the counterpart of
/// parseCustomDirectiveOperandsAndTypes.
static void printCustomDirectiveOperandsAndTypes(
    OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
    OperandRange varOperands, Type operandType, Type optOperandType,
    TypeRange varOperandTypes) {
  printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
  printCustomDirectiveResults(printer, op, operandType, optOperandType,
                              varOperandTypes);
}
405 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
406                                         Region &region,
407                                         MutableArrayRef<Region> varRegions) {
408   printer.printRegion(region);
409   if (!varRegions.empty()) {
410     printer << ", ";
411     for (Region &region : varRegions)
412       printer.printRegion(region);
413   }
414 }
415 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
416                                            Block *successor,
417                                            SuccessorRange varSuccessors) {
418   printer << successor;
419   if (!varSuccessors.empty())
420     printer << ", " << varSuccessors.front();
421 }
422 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
423                                            Attribute attribute,
424                                            Attribute optAttribute) {
425   printer << attribute;
426   if (optAttribute)
427     printer << ", " << optAttribute;
428 }
429 
430 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
431                                          DictionaryAttr attrs) {
432   printer.printOptionalAttrDict(attrs.getValue());
433 }
434 
435 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
436                                                    Operation *op,
437                                                    Value optOperand) {
438   printer << (optOperand ? "1" : "0");
439 }
440 
441 //===----------------------------------------------------------------------===//
442 // Test IsolatedRegionOp - parse passthrough region arguments.
443 //===----------------------------------------------------------------------===//
444 
445 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
446                                          OperationState &result) {
447   OpAsmParser::OperandType argInfo;
448   Type argType = parser.getBuilder().getIndexType();
449 
450   // Parse the input operand.
451   if (parser.parseOperand(argInfo) ||
452       parser.resolveOperand(argInfo, argType, result.operands))
453     return failure();
454 
455   // Parse the body region, and reuse the operand info as the argument info.
456   Region *body = result.addRegion();
457   return parser.parseRegion(*body, argInfo, argType,
458                             /*enableNameShadowing=*/true);
459 }
460 
461 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
462   p << "test.isolated_region ";
463   p.printOperand(op.getOperand());
464   p.shadowRegionArgs(op.region(), op.getOperand());
465   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
466 }
467 
468 //===----------------------------------------------------------------------===//
469 // Test SSACFGRegionOp
470 //===----------------------------------------------------------------------===//
471 
472 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
473   return RegionKind::SSACFG;
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // Test GraphRegionOp
478 //===----------------------------------------------------------------------===//
479 
480 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
481                                       OperationState &result) {
482   // Parse the body region, and reuse the operand info as the argument info.
483   Region *body = result.addRegion();
484   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
485 }
486 
487 static void print(OpAsmPrinter &p, GraphRegionOp op) {
488   p << "test.graph_region ";
489   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
490 }
491 
492 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
493   return RegionKind::Graph;
494 }
495 
496 //===----------------------------------------------------------------------===//
497 // Test AffineScopeOp
498 //===----------------------------------------------------------------------===//
499 
500 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
501                                       OperationState &result) {
502   // Parse the body region, and reuse the operand info as the argument info.
503   Region *body = result.addRegion();
504   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
505 }
506 
507 static void print(OpAsmPrinter &p, AffineScopeOp op) {
508   p << "test.affine_scope ";
509   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // Test parser.
514 //===----------------------------------------------------------------------===//
515 
516 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
517                                               OperationState &result) {
518   if (parser.parseOptionalColon())
519     return success();
520   uint64_t numResults;
521   if (parser.parseInteger(numResults))
522     return failure();
523 
524   IndexType type = parser.getBuilder().getIndexType();
525   for (unsigned i = 0; i < numResults; ++i)
526     result.addTypes(type);
527   return success();
528 }
529 
530 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
531   p << ParseIntegerLiteralOp::getOperationName();
532   if (unsigned numResults = op->getNumResults())
533     p << " : " << numResults;
534 }
535 
536 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
537                                               OperationState &result) {
538   StringRef keyword;
539   if (parser.parseKeyword(&keyword))
540     return failure();
541   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
542   return success();
543 }
544 
545 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
546   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
547 }
548 
549 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
551 
552 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
553                                          OperationState &result) {
554   if (parser.parseKeyword("wraps"))
555     return failure();
556 
557   // Parse the wrapped op in a region
558   Region &body = *result.addRegion();
559   body.push_back(new Block);
560   Block &block = body.back();
561   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
562   if (!wrapped_op)
563     return failure();
564 
565   // Create a return terminator in the inner region, pass as operand to the
566   // terminator the returned values from the wrapped operation.
567   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
568   OpBuilder builder(parser.getBuilder().getContext());
569   builder.setInsertionPointToEnd(&block);
570   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
571 
572   // Get the results type for the wrapping op from the terminator operands.
573   Operation &return_op = body.back().back();
574   result.types.append(return_op.operand_type_begin(),
575                       return_op.operand_type_end());
576 
577   // Use the location of the wrapped op for the "test.wrapping_region" op.
578   result.location = wrapped_op->getLoc();
579 
580   return success();
581 }
582 
583 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
584   p << op.getOperationName() << " wraps ";
585   p.printGenericOp(&op.region().front().front());
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // Test PolyForOp - parse list of region arguments.
590 //===----------------------------------------------------------------------===//
591 
592 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
593   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
594   // Parse list of region arguments without a delimiter.
595   if (parser.parseRegionArgumentList(ivsInfo))
596     return failure();
597 
598   // Parse the body region.
599   Region *body = result.addRegion();
600   auto &builder = parser.getBuilder();
601   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
602   return parser.parseRegion(*body, ivsInfo, argTypes);
603 }
604 
605 //===----------------------------------------------------------------------===//
606 // Test removing op with inner ops.
607 //===----------------------------------------------------------------------===//
608 
609 namespace {
610 struct TestRemoveOpWithInnerOps
611     : public OpRewritePattern<TestOpWithRegionPattern> {
612   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
613 
614   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
615                                 PatternRewriter &rewriter) const override {
616     rewriter.eraseOp(op);
617     return success();
618   }
619 };
620 } // end anonymous namespace
621 
/// Registers TestRemoveOpWithInnerOps, which erases the op during
/// canonicalization.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TestRemoveOpWithInnerOps>(context);
}
626 
627 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
628   return operand();
629 }
630 
631 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
632   return getValue();
633 }
634 
635 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
636     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
637   for (Value input : this->operands()) {
638     results.push_back(input);
639   }
640   return success();
641 }
642 
643 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
644   assert(operands.size() == 1);
645   if (operands.front()) {
646     (*this)->setAttr("attr", operands.front());
647     return getResult();
648   }
649   return {};
650 }
651 
652 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
653   return getOperand();
654 }
655 
656 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
657     MLIRContext *, Optional<Location> location, ValueRange operands,
658     DictionaryAttr attributes, RegionRange regions,
659     SmallVectorImpl<Type> &inferredReturnTypes) {
660   if (operands[0].getType() != operands[1].getType()) {
661     return emitOptionalError(location, "operand type mismatch ",
662                              operands[0].getType(), " vs ",
663                              operands[1].getType());
664   }
665   inferredReturnTypes.assign({operands[0].getType()});
666   return success();
667 }
668 
669 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
670     MLIRContext *context, Optional<Location> location, ValueRange operands,
671     DictionaryAttr attributes, RegionRange regions,
672     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
673   // Create return type consisting of the last element of the first operand.
674   auto operandType = *operands.getTypes().begin();
675   auto sval = operandType.dyn_cast<ShapedType>();
676   if (!sval) {
677     return emitOptionalError(location, "only shaped type operands allowed");
678   }
679   int64_t dim =
680       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
681   auto type = IntegerType::get(context, 17);
682   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
683   return success();
684 }
685 
686 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
687     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
688   shapes = SmallVector<Value, 1>{
689       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
690   return success();
691 }
692 
693 //===----------------------------------------------------------------------===//
694 // Test SideEffect interfaces
695 //===----------------------------------------------------------------------===//
696 
697 namespace {
/// A test resource for side effects. SideEffectOp uses it instead of the
/// default resource when an effect entry has a `test_resource` key.
struct TestResource : public SideEffects::Resource::Base<TestResource> {
  // Name under which this resource is reported.
  StringRef getName() final { return "<Test>"; }
};
702 } // end anonymous namespace
703 
704 void SideEffectOp::getEffects(
705     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
706   // Check for an effects attribute on the op instance.
707   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
708   if (!effectsAttr)
709     return;
710 
711   // If there is one, it is an array of dictionary attributes that hold
712   // information on the effects of this operation.
713   for (Attribute element : effectsAttr) {
714     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
715 
716     // Get the specific memory effect.
717     MemoryEffects::Effect *effect =
718         StringSwitch<MemoryEffects::Effect *>(
719             effectElement.get("effect").cast<StringAttr>().getValue())
720             .Case("allocate", MemoryEffects::Allocate::get())
721             .Case("free", MemoryEffects::Free::get())
722             .Case("read", MemoryEffects::Read::get())
723             .Case("write", MemoryEffects::Write::get());
724 
725     // Check for a non-default resource to use.
726     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
727     if (effectElement.get("test_resource"))
728       resource = TestResource::get();
729 
730     // Check for a result to affect.
731     if (effectElement.get("on_result"))
732       effects.emplace_back(effect, getResult(), resource);
733     else if (Attribute ref = effectElement.get("on_reference"))
734       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
735     else
736       effects.emplace_back(effect, resource);
737   }
738 }
739 
740 void SideEffectOp::getEffects(
741     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
742   auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
743   if (!effectsAttr)
744     return;
745 
746   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // StringAttrPrettyNameOp
751 //===----------------------------------------------------------------------===//
752 
753 // This op has fancy handling of its SSA result name.
754 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
755                                                OperationState &result) {
756   // Add the result types.
757   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
758     result.addTypes(parser.getBuilder().getIntegerType(32));
759 
760   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
761     return failure();
762 
763   // If the attribute dictionary contains no 'names' attribute, infer it from
764   // the SSA name (if specified).
765   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
766     return attr.first == "names";
767   });
768 
769   // If there was no name specified, check to see if there was a useful name
770   // specified in the asm file.
771   if (hadNames || parser.getNumResults() == 0)
772     return success();
773 
774   SmallVector<StringRef, 4> names;
775   auto *context = result.getContext();
776 
777   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
778     auto resultName = parser.getResultName(i);
779     StringRef nameStr;
780     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
781       nameStr = resultName.first;
782 
783     names.push_back(nameStr);
784   }
785 
786   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
787   result.attributes.push_back({Identifier::get("names", context), namesAttr});
788   return success();
789 }
790 
791 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
792   p << "test.string_attr_pretty_name";
793 
794   // Note that we only need to print the "name" attribute if the asmprinter
795   // result name disagrees with it.  This can happen in strange cases, e.g.
796   // when there are conflicts.
797   bool namesDisagree = op.names().size() != op.getNumResults();
798 
799   SmallString<32> resultNameStr;
800   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
801     resultNameStr.clear();
802     llvm::raw_svector_ostream tmpStream(resultNameStr);
803     p.printOperand(op.getResult(i), tmpStream);
804 
805     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
806     if (!expectedName ||
807         tmpStream.str().drop_front() != expectedName.getValue()) {
808       namesDisagree = true;
809     }
810   }
811 
812   if (namesDisagree)
813     p.printOptionalAttrDictWithKeyword(op->getAttrs());
814   else
815     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
816 }
817 
818 // We set the SSA name in the asm syntax to the contents of the name
819 // attribute.
820 void StringAttrPrettyNameOp::getAsmResultNames(
821     function_ref<void(Value, StringRef)> setNameFn) {
822 
823   auto value = names();
824   for (size_t i = 0, e = value.size(); i != e; ++i)
825     if (auto str = value[i].dyn_cast<StringAttr>())
826       if (!str.getValue().empty())
827         setNameFn(getResult(i), str.getValue());
828 }
829 
830 //===----------------------------------------------------------------------===//
831 // RegionIfOp
832 //===----------------------------------------------------------------------===//
833 
834 static void print(OpAsmPrinter &p, RegionIfOp op) {
835   p << RegionIfOp::getOperationName() << " ";
836   p.printOperands(op.getOperands());
837   p << ": " << op.getOperandTypes();
838   p.printArrowTypeList(op.getResultTypes());
839   p << " then";
840   p.printRegion(op.thenRegion(),
841                 /*printEntryBlockArgs=*/true,
842                 /*printBlockTerminators=*/true);
843   p << " else";
844   p.printRegion(op.elseRegion(),
845                 /*printEntryBlockArgs=*/true,
846                 /*printBlockTerminators=*/true);
847   p << " join";
848   p.printRegion(op.joinRegion(),
849                 /*printEntryBlockArgs=*/true,
850                 /*printBlockTerminators=*/true);
851 }
852 
853 static ParseResult parseRegionIfOp(OpAsmParser &parser,
854                                    OperationState &result) {
855   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
856   SmallVector<Type, 2> operandTypes;
857 
858   result.regions.reserve(3);
859   Region *thenRegion = result.addRegion();
860   Region *elseRegion = result.addRegion();
861   Region *joinRegion = result.addRegion();
862 
863   // Parse operand, type and arrow type lists.
864   if (parser.parseOperandList(operandInfos) ||
865       parser.parseColonTypeList(operandTypes) ||
866       parser.parseArrowTypeList(result.types))
867     return failure();
868 
869   // Parse all attached regions.
870   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
871       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
872       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
873     return failure();
874 
875   return parser.resolveOperands(operandInfos, operandTypes,
876                                 parser.getCurrentLocation(), result.operands);
877 }
878 
879 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
880   assert(index < 2 && "invalid region index");
881   return getOperands();
882 }
883 
884 void RegionIfOp::getSuccessorRegions(
885     Optional<unsigned> index, ArrayRef<Attribute> operands,
886     SmallVectorImpl<RegionSuccessor> &regions) {
887   // We always branch to the join region.
888   if (index.hasValue()) {
889     if (index.getValue() < 2)
890       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
891     else
892       regions.push_back(RegionSuccessor(getResults()));
893     return;
894   }
895 
896   // The then and else regions are the entry regions of this op.
897   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
898   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
899 }
900 
901 #include "TestOpEnums.cpp.inc"
902 #include "TestOpInterfaces.cpp.inc"
903 #include "TestOpStructs.cpp.inc"
904 #include "TestTypeInterfaces.cpp.inc"
905 
906 #define GET_OP_CLASSES
907 #include "TestOps.cpp.inc"
908