1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 #include "mlir/Transforms/FoldUtils.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/StringSwitch.h"
19 
20 using namespace mlir;
21 using namespace mlir::test;
22 
23 void mlir::test::registerTestDialect(DialectRegistry &registry) {
24   registry.insert<TestDialect>();
25 }
26 
27 //===----------------------------------------------------------------------===//
28 // TestDialect Interfaces
29 //===----------------------------------------------------------------------===//
30 
31 namespace {
32 
33 // Test support for interacting with the AsmPrinter.
34 struct TestOpAsmInterface : public OpAsmDialectInterface {
35   using OpAsmDialectInterface::OpAsmDialectInterface;
36 
37   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
38     StringAttr strAttr = attr.dyn_cast<StringAttr>();
39     if (!strAttr)
40       return failure();
41 
42     // Check the contents of the string attribute to see what the test alias
43     // should be named.
44     Optional<StringRef> aliasName =
45         StringSwitch<Optional<StringRef>>(strAttr.getValue())
46             .Case("alias_test:dot_in_name", StringRef("test.alias"))
47             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
48             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
49             .Case("alias_test:sanitize_conflict_a",
50                   StringRef("test_alias_conflict0"))
51             .Case("alias_test:sanitize_conflict_b",
52                   StringRef("test_alias_conflict0_"))
53             .Default(llvm::None);
54     if (!aliasName)
55       return failure();
56 
57     os << *aliasName;
58     return success();
59   }
60 
61   void getAsmResultNames(Operation *op,
62                          OpAsmSetValueNameFn setNameFn) const final {
63     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
64       setNameFn(asmOp, "result");
65   }
66 
67   void getAsmBlockArgumentNames(Block *block,
68                                 OpAsmSetValueNameFn setNameFn) const final {
69     auto op = block->getParentOp();
70     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
71     if (!arrayAttr)
72       return;
73     auto args = block->getArguments();
74     auto e = std::min(arrayAttr.size(), args.size());
75     for (unsigned i = 0; i < e; ++i) {
76       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
77         setNameFn(args[i], strAttr.getValue());
78     }
79   }
80 };
81 
82 struct TestDialectFoldInterface : public DialectFoldInterface {
83   using DialectFoldInterface::DialectFoldInterface;
84 
85   /// Registered hook to check if the given region, which is attached to an
86   /// operation that is *not* isolated from above, should be used when
87   /// materializing constants.
88   bool shouldMaterializeInto(Region *region) const final {
89     // If this is a one region operation, then insert into it.
90     return isa<OneRegionOp>(region->getParentOp());
91   }
92 };
93 
94 /// This class defines the interface for handling inlining with standard
95 /// operations.
96 struct TestInlinerInterface : public DialectInlinerInterface {
97   using DialectInlinerInterface::DialectInlinerInterface;
98 
99   //===--------------------------------------------------------------------===//
100   // Analysis Hooks
101   //===--------------------------------------------------------------------===//
102 
103   bool isLegalToInline(Operation *call, Operation *callable,
104                        bool wouldBeCloned) const final {
105     // Don't allow inlining calls that are marked `noinline`.
106     return !call->hasAttr("noinline");
107   }
108   bool isLegalToInline(Region *, Region *, bool,
109                        BlockAndValueMapping &) const final {
110     // Inlining into test dialect regions is legal.
111     return true;
112   }
113   bool isLegalToInline(Operation *, Region *, bool,
114                        BlockAndValueMapping &) const final {
115     return true;
116   }
117 
118   bool shouldAnalyzeRecursively(Operation *op) const final {
119     // Analyze recursively if this is not a functional region operation, it
120     // froms a separate functional scope.
121     return !isa<FunctionalRegionOp>(op);
122   }
123 
124   //===--------------------------------------------------------------------===//
125   // Transformation Hooks
126   //===--------------------------------------------------------------------===//
127 
128   /// Handle the given inlined terminator by replacing it with a new operation
129   /// as necessary.
130   void handleTerminator(Operation *op,
131                         ArrayRef<Value> valuesToRepl) const final {
132     // Only handle "test.return" here.
133     auto returnOp = dyn_cast<TestReturnOp>(op);
134     if (!returnOp)
135       return;
136 
137     // Replace the values directly with the return operands.
138     assert(returnOp.getNumOperands() == valuesToRepl.size());
139     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
140       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
141   }
142 
143   /// Attempt to materialize a conversion for a type mismatch between a call
144   /// from this dialect, and a callable region. This method should generate an
145   /// operation that takes 'input' as the only operand, and produces a single
146   /// result of 'resultType'. If a conversion can not be generated, nullptr
147   /// should be returned.
148   Operation *materializeCallConversion(OpBuilder &builder, Value input,
149                                        Type resultType,
150                                        Location conversionLoc) const final {
151     // Only allow conversion for i16/i32 types.
152     if (!(resultType.isSignlessInteger(16) ||
153           resultType.isSignlessInteger(32)) ||
154         !(input.getType().isSignlessInteger(16) ||
155           input.getType().isSignlessInteger(32)))
156       return nullptr;
157     return builder.create<TestCastOp>(conversionLoc, resultType, input);
158   }
159 };
160 } // end anonymous namespace
161 
162 //===----------------------------------------------------------------------===//
163 // TestDialect
164 //===----------------------------------------------------------------------===//
165 
/// Registers the test dialect's operations, dialect interfaces and types.
/// The operation list and the type-def list are pulled in from the generated
/// "TestOps.cpp.inc" and "TestTypeDefs.cpp.inc" files.
void TestDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  // Interfaces defined earlier in this file (asm printing, folding, inlining).
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  // TestType and TestRecursiveType are listed explicitly, ahead of the
  // generated type-def list.
  addTypes<TestType, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  // NOTE(review): permits operations that are not registered with this
  // dialect (presumably `test.*` ops without a definition) — confirm against
  // Dialect::allowUnknownOperations.
  allowUnknownOperations();
}
179 
180 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
181                                             Type type, Location loc) {
182   return builder.create<TestOpConstant>(loc, type, value);
183 }
184 
185 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
186                                                     NamedAttribute namedAttr) {
187   if (namedAttr.first == "test.invalid_attr")
188     return op->emitError() << "invalid to use 'test.invalid_attr'";
189   return success();
190 }
191 
192 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
193                                                     unsigned regionIndex,
194                                                     unsigned argIndex,
195                                                     NamedAttribute namedAttr) {
196   if (namedAttr.first == "test.invalid_attr")
197     return op->emitError() << "invalid to use 'test.invalid_attr'";
198   return success();
199 }
200 
201 LogicalResult
202 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
203                                          unsigned resultIndex,
204                                          NamedAttribute namedAttr) {
205   if (namedAttr.first == "test.invalid_attr")
206     return op->emitError() << "invalid to use 'test.invalid_attr'";
207   return success();
208 }
209 
210 //===----------------------------------------------------------------------===//
211 // TestBranchOp
212 //===----------------------------------------------------------------------===//
213 
214 Optional<MutableOperandRange>
215 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
216   assert(index == 0 && "invalid successor index");
217   return targetOperandsMutable();
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // TestFoldToCallOp
222 //===----------------------------------------------------------------------===//
223 
224 namespace {
225 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
226   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
227 
228   LogicalResult matchAndRewrite(FoldToCallOp op,
229                                 PatternRewriter &rewriter) const override {
230     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
231                                         ValueRange());
232     return success();
233   }
234 };
235 } // end anonymous namespace
236 
237 void FoldToCallOp::getCanonicalizationPatterns(
238     OwningRewritePatternList &results, MLIRContext *context) {
239   results.insert<FoldToCallOpPattern>(context);
240 }
241 
242 //===----------------------------------------------------------------------===//
243 // Test Format* operations
244 //===----------------------------------------------------------------------===//
245 
246 //===----------------------------------------------------------------------===//
247 // Parsing
248 
249 static ParseResult parseCustomDirectiveOperands(
250     OpAsmParser &parser, OpAsmParser::OperandType &operand,
251     Optional<OpAsmParser::OperandType> &optOperand,
252     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
253   if (parser.parseOperand(operand))
254     return failure();
255   if (succeeded(parser.parseOptionalComma())) {
256     optOperand.emplace();
257     if (parser.parseOperand(*optOperand))
258       return failure();
259   }
260   if (parser.parseArrow() || parser.parseLParen() ||
261       parser.parseOperandList(varOperands) || parser.parseRParen())
262     return failure();
263   return success();
264 }
265 static ParseResult
266 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
267                             Type &optOperandType,
268                             SmallVectorImpl<Type> &varOperandTypes) {
269   if (parser.parseColon())
270     return failure();
271 
272   if (parser.parseType(operandType))
273     return failure();
274   if (succeeded(parser.parseOptionalComma())) {
275     if (parser.parseType(optOperandType))
276       return failure();
277   }
278   if (parser.parseArrow() || parser.parseLParen() ||
279       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
280     return failure();
281   return success();
282 }
283 static ParseResult
284 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
285                                  Type optOperandType,
286                                  const SmallVectorImpl<Type> &varOperandTypes) {
287   if (parser.parseKeyword("type_refs_capture"))
288     return failure();
289 
290   Type operandType2, optOperandType2;
291   SmallVector<Type, 1> varOperandTypes2;
292   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
293                                   varOperandTypes2))
294     return failure();
295 
296   if (operandType != operandType2 || optOperandType != optOperandType2 ||
297       varOperandTypes != varOperandTypes2)
298     return failure();
299 
300   return success();
301 }
302 static ParseResult parseCustomDirectiveOperandsAndTypes(
303     OpAsmParser &parser, OpAsmParser::OperandType &operand,
304     Optional<OpAsmParser::OperandType> &optOperand,
305     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
306     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
307   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
308       parseCustomDirectiveResults(parser, operandType, optOperandType,
309                                   varOperandTypes))
310     return failure();
311   return success();
312 }
313 static ParseResult parseCustomDirectiveRegions(
314     OpAsmParser &parser, Region &region,
315     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
316   if (parser.parseRegion(region))
317     return failure();
318   if (failed(parser.parseOptionalComma()))
319     return success();
320   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
321   if (parser.parseRegion(*varRegion))
322     return failure();
323   varRegions.emplace_back(std::move(varRegion));
324   return success();
325 }
326 static ParseResult
327 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
328                                SmallVectorImpl<Block *> &varSuccessors) {
329   if (parser.parseSuccessor(successor))
330     return failure();
331   if (failed(parser.parseOptionalComma()))
332     return success();
333   Block *varSuccessor;
334   if (parser.parseSuccessor(varSuccessor))
335     return failure();
336   varSuccessors.append(2, varSuccessor);
337   return success();
338 }
339 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
340                                                   IntegerAttr &attr,
341                                                   IntegerAttr &optAttr) {
342   if (parser.parseAttribute(attr))
343     return failure();
344   if (succeeded(parser.parseOptionalComma())) {
345     if (parser.parseAttribute(optAttr))
346       return failure();
347   }
348   return success();
349 }
350 
351 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
352                                                 NamedAttrList &attrs) {
353   return parser.parseOptionalAttrDict(attrs);
354 }
355 static ParseResult parseCustomDirectiveOptionalOperandRef(
356     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
357   int64_t operandCount = 0;
358   if (parser.parseInteger(operandCount))
359     return failure();
360   bool expectedOptionalOperand = operandCount == 0;
361   return success(expectedOptionalOperand != optOperand.hasValue());
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // Printing
366 
367 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
368                                          Value operand, Value optOperand,
369                                          OperandRange varOperands) {
370   printer << operand;
371   if (optOperand)
372     printer << ", " << optOperand;
373   printer << " -> (" << varOperands << ")";
374 }
375 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
376                                         Type operandType, Type optOperandType,
377                                         TypeRange varOperandTypes) {
378   printer << " : " << operandType;
379   if (optOperandType)
380     printer << ", " << optOperandType;
381   printer << " -> (" << varOperandTypes << ")";
382 }
383 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
384                                              Operation *op, Type operandType,
385                                              Type optOperandType,
386                                              TypeRange varOperandTypes) {
387   printer << " type_refs_capture ";
388   printCustomDirectiveResults(printer, op, operandType, optOperandType,
389                               varOperandTypes);
390 }
391 static void printCustomDirectiveOperandsAndTypes(
392     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
393     OperandRange varOperands, Type operandType, Type optOperandType,
394     TypeRange varOperandTypes) {
395   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
396   printCustomDirectiveResults(printer, op, operandType, optOperandType,
397                               varOperandTypes);
398 }
399 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
400                                         Region &region,
401                                         MutableArrayRef<Region> varRegions) {
402   printer.printRegion(region);
403   if (!varRegions.empty()) {
404     printer << ", ";
405     for (Region &region : varRegions)
406       printer.printRegion(region);
407   }
408 }
409 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
410                                            Block *successor,
411                                            SuccessorRange varSuccessors) {
412   printer << successor;
413   if (!varSuccessors.empty())
414     printer << ", " << varSuccessors.front();
415 }
416 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
417                                            Attribute attribute,
418                                            Attribute optAttribute) {
419   printer << attribute;
420   if (optAttribute)
421     printer << ", " << optAttribute;
422 }
423 
424 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
425                                          DictionaryAttr attrs) {
426   printer.printOptionalAttrDict(attrs.getValue());
427 }
428 
429 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
430                                                    Operation *op,
431                                                    Value optOperand) {
432   printer << (optOperand ? "1" : "0");
433 }
434 
435 //===----------------------------------------------------------------------===//
436 // Test IsolatedRegionOp - parse passthrough region arguments.
437 //===----------------------------------------------------------------------===//
438 
439 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
440                                          OperationState &result) {
441   OpAsmParser::OperandType argInfo;
442   Type argType = parser.getBuilder().getIndexType();
443 
444   // Parse the input operand.
445   if (parser.parseOperand(argInfo) ||
446       parser.resolveOperand(argInfo, argType, result.operands))
447     return failure();
448 
449   // Parse the body region, and reuse the operand info as the argument info.
450   Region *body = result.addRegion();
451   return parser.parseRegion(*body, argInfo, argType,
452                             /*enableNameShadowing=*/true);
453 }
454 
455 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
456   p << "test.isolated_region ";
457   p.printOperand(op.getOperand());
458   p.shadowRegionArgs(op.region(), op.getOperand());
459   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
460 }
461 
462 //===----------------------------------------------------------------------===//
463 // Test SSACFGRegionOp
464 //===----------------------------------------------------------------------===//
465 
466 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
467   return RegionKind::SSACFG;
468 }
469 
470 //===----------------------------------------------------------------------===//
471 // Test GraphRegionOp
472 //===----------------------------------------------------------------------===//
473 
474 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
475                                       OperationState &result) {
476   // Parse the body region, and reuse the operand info as the argument info.
477   Region *body = result.addRegion();
478   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
479 }
480 
481 static void print(OpAsmPrinter &p, GraphRegionOp op) {
482   p << "test.graph_region ";
483   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
484 }
485 
486 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
487   return RegionKind::Graph;
488 }
489 
490 //===----------------------------------------------------------------------===//
491 // Test AffineScopeOp
492 //===----------------------------------------------------------------------===//
493 
494 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
495                                       OperationState &result) {
496   // Parse the body region, and reuse the operand info as the argument info.
497   Region *body = result.addRegion();
498   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
499 }
500 
501 static void print(OpAsmPrinter &p, AffineScopeOp op) {
502   p << "test.affine_scope ";
503   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // Test parser.
508 //===----------------------------------------------------------------------===//
509 
510 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
511                                               OperationState &result) {
512   if (parser.parseOptionalColon())
513     return success();
514   uint64_t numResults;
515   if (parser.parseInteger(numResults))
516     return failure();
517 
518   IndexType type = parser.getBuilder().getIndexType();
519   for (unsigned i = 0; i < numResults; ++i)
520     result.addTypes(type);
521   return success();
522 }
523 
524 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
525   p << ParseIntegerLiteralOp::getOperationName();
526   if (unsigned numResults = op->getNumResults())
527     p << " : " << numResults;
528 }
529 
530 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
531                                               OperationState &result) {
532   StringRef keyword;
533   if (parser.parseKeyword(&keyword))
534     return failure();
535   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
536   return success();
537 }
538 
539 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
540   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
541 }
542 
543 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
545 
546 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
547                                          OperationState &result) {
548   if (parser.parseKeyword("wraps"))
549     return failure();
550 
551   // Parse the wrapped op in a region
552   Region &body = *result.addRegion();
553   body.push_back(new Block);
554   Block &block = body.back();
555   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
556   if (!wrapped_op)
557     return failure();
558 
559   // Create a return terminator in the inner region, pass as operand to the
560   // terminator the returned values from the wrapped operation.
561   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
562   OpBuilder builder(parser.getBuilder().getContext());
563   builder.setInsertionPointToEnd(&block);
564   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
565 
566   // Get the results type for the wrapping op from the terminator operands.
567   Operation &return_op = body.back().back();
568   result.types.append(return_op.operand_type_begin(),
569                       return_op.operand_type_end());
570 
571   // Use the location of the wrapped op for the "test.wrapping_region" op.
572   result.location = wrapped_op->getLoc();
573 
574   return success();
575 }
576 
577 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
578   p << op.getOperationName() << " wraps ";
579   p.printGenericOp(&op.region().front().front());
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // Test PolyForOp - parse list of region arguments.
584 //===----------------------------------------------------------------------===//
585 
586 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
587   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
588   // Parse list of region arguments without a delimiter.
589   if (parser.parseRegionArgumentList(ivsInfo))
590     return failure();
591 
592   // Parse the body region.
593   Region *body = result.addRegion();
594   auto &builder = parser.getBuilder();
595   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
596   return parser.parseRegion(*body, ivsInfo, argTypes);
597 }
598 
599 //===----------------------------------------------------------------------===//
600 // Test removing op with inner ops.
601 //===----------------------------------------------------------------------===//
602 
603 namespace {
604 struct TestRemoveOpWithInnerOps
605     : public OpRewritePattern<TestOpWithRegionPattern> {
606   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
607 
608   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
609                                 PatternRewriter &rewriter) const override {
610     rewriter.eraseOp(op);
611     return success();
612   }
613 };
614 } // end anonymous namespace
615 
616 void TestOpWithRegionPattern::getCanonicalizationPatterns(
617     OwningRewritePatternList &results, MLIRContext *context) {
618   results.insert<TestRemoveOpWithInnerOps>(context);
619 }
620 
621 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
622   return operand();
623 }
624 
625 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
626   return getValue();
627 }
628 
629 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
630     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
631   for (Value input : this->operands()) {
632     results.push_back(input);
633   }
634   return success();
635 }
636 
637 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
638   assert(operands.size() == 1);
639   if (operands.front()) {
640     (*this)->setAttr("attr", operands.front());
641     return getResult();
642   }
643   return {};
644 }
645 
646 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
647   return getOperand();
648 }
649 
650 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
651     MLIRContext *, Optional<Location> location, ValueRange operands,
652     DictionaryAttr attributes, RegionRange regions,
653     SmallVectorImpl<Type> &inferredReturnTypes) {
654   if (operands[0].getType() != operands[1].getType()) {
655     return emitOptionalError(location, "operand type mismatch ",
656                              operands[0].getType(), " vs ",
657                              operands[1].getType());
658   }
659   inferredReturnTypes.assign({operands[0].getType()});
660   return success();
661 }
662 
663 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
664     MLIRContext *context, Optional<Location> location, ValueRange operands,
665     DictionaryAttr attributes, RegionRange regions,
666     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
667   // Create return type consisting of the last element of the first operand.
668   auto operandType = *operands.getTypes().begin();
669   auto sval = operandType.dyn_cast<ShapedType>();
670   if (!sval) {
671     return emitOptionalError(location, "only shaped type operands allowed");
672   }
673   int64_t dim =
674       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
675   auto type = IntegerType::get(context, 17);
676   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
677   return success();
678 }
679 
680 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
681     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
682   shapes = SmallVector<Value, 1>{
683       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
684   return success();
685 }
686 
687 //===----------------------------------------------------------------------===//
688 // Test SideEffect interfaces
689 //===----------------------------------------------------------------------===//
690 
691 namespace {
692 /// A test resource for side effects.
693 struct TestResource : public SideEffects::Resource::Base<TestResource> {
694   StringRef getName() final { return "<Test>"; }
695 };
696 } // end anonymous namespace
697 
698 void SideEffectOp::getEffects(
699     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
700   // Check for an effects attribute on the op instance.
701   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
702   if (!effectsAttr)
703     return;
704 
705   // If there is one, it is an array of dictionary attributes that hold
706   // information on the effects of this operation.
707   for (Attribute element : effectsAttr) {
708     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
709 
710     // Get the specific memory effect.
711     MemoryEffects::Effect *effect =
712         StringSwitch<MemoryEffects::Effect *>(
713             effectElement.get("effect").cast<StringAttr>().getValue())
714             .Case("allocate", MemoryEffects::Allocate::get())
715             .Case("free", MemoryEffects::Free::get())
716             .Case("read", MemoryEffects::Read::get())
717             .Case("write", MemoryEffects::Write::get());
718 
719     // Check for a non-default resource to use.
720     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
721     if (effectElement.get("test_resource"))
722       resource = TestResource::get();
723 
724     // Check for a result to affect.
725     if (effectElement.get("on_result"))
726       effects.emplace_back(effect, getResult(), resource);
727     else if (Attribute ref = effectElement.get("on_reference"))
728       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
729     else
730       effects.emplace_back(effect, resource);
731   }
732 }
733 
734 void SideEffectOp::getEffects(
735     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
736   auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
737   if (!effectsAttr)
738     return;
739 
740   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // StringAttrPrettyNameOp
745 //===----------------------------------------------------------------------===//
746 
747 // This op has fancy handling of its SSA result name.
748 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
749                                                OperationState &result) {
750   // Add the result types.
751   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
752     result.addTypes(parser.getBuilder().getIntegerType(32));
753 
754   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
755     return failure();
756 
757   // If the attribute dictionary contains no 'names' attribute, infer it from
758   // the SSA name (if specified).
759   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
760     return attr.first == "names";
761   });
762 
763   // If there was no name specified, check to see if there was a useful name
764   // specified in the asm file.
765   if (hadNames || parser.getNumResults() == 0)
766     return success();
767 
768   SmallVector<StringRef, 4> names;
769   auto *context = result.getContext();
770 
771   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
772     auto resultName = parser.getResultName(i);
773     StringRef nameStr;
774     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
775       nameStr = resultName.first;
776 
777     names.push_back(nameStr);
778   }
779 
780   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
781   result.attributes.push_back({Identifier::get("names", context), namesAttr});
782   return success();
783 }
784 
785 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
786   p << "test.string_attr_pretty_name";
787 
788   // Note that we only need to print the "name" attribute if the asmprinter
789   // result name disagrees with it.  This can happen in strange cases, e.g.
790   // when there are conflicts.
791   bool namesDisagree = op.names().size() != op.getNumResults();
792 
793   SmallString<32> resultNameStr;
794   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
795     resultNameStr.clear();
796     llvm::raw_svector_ostream tmpStream(resultNameStr);
797     p.printOperand(op.getResult(i), tmpStream);
798 
799     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
800     if (!expectedName ||
801         tmpStream.str().drop_front() != expectedName.getValue()) {
802       namesDisagree = true;
803     }
804   }
805 
806   if (namesDisagree)
807     p.printOptionalAttrDictWithKeyword(op->getAttrs());
808   else
809     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
810 }
811 
812 // We set the SSA name in the asm syntax to the contents of the name
813 // attribute.
814 void StringAttrPrettyNameOp::getAsmResultNames(
815     function_ref<void(Value, StringRef)> setNameFn) {
816 
817   auto value = names();
818   for (size_t i = 0, e = value.size(); i != e; ++i)
819     if (auto str = value[i].dyn_cast<StringAttr>())
820       if (!str.getValue().empty())
821         setNameFn(getResult(i), str.getValue());
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // RegionIfOp
826 //===----------------------------------------------------------------------===//
827 
828 static void print(OpAsmPrinter &p, RegionIfOp op) {
829   p << RegionIfOp::getOperationName() << " ";
830   p.printOperands(op.getOperands());
831   p << ": " << op.getOperandTypes();
832   p.printArrowTypeList(op.getResultTypes());
833   p << " then";
834   p.printRegion(op.thenRegion(),
835                 /*printEntryBlockArgs=*/true,
836                 /*printBlockTerminators=*/true);
837   p << " else";
838   p.printRegion(op.elseRegion(),
839                 /*printEntryBlockArgs=*/true,
840                 /*printBlockTerminators=*/true);
841   p << " join";
842   p.printRegion(op.joinRegion(),
843                 /*printEntryBlockArgs=*/true,
844                 /*printBlockTerminators=*/true);
845 }
846 
847 static ParseResult parseRegionIfOp(OpAsmParser &parser,
848                                    OperationState &result) {
849   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
850   SmallVector<Type, 2> operandTypes;
851 
852   result.regions.reserve(3);
853   Region *thenRegion = result.addRegion();
854   Region *elseRegion = result.addRegion();
855   Region *joinRegion = result.addRegion();
856 
857   // Parse operand, type and arrow type lists.
858   if (parser.parseOperandList(operandInfos) ||
859       parser.parseColonTypeList(operandTypes) ||
860       parser.parseArrowTypeList(result.types))
861     return failure();
862 
863   // Parse all attached regions.
864   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
865       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
866       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
867     return failure();
868 
869   return parser.resolveOperands(operandInfos, operandTypes,
870                                 parser.getCurrentLocation(), result.operands);
871 }
872 
873 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
874   assert(index < 2 && "invalid region index");
875   return getOperands();
876 }
877 
878 void RegionIfOp::getSuccessorRegions(
879     Optional<unsigned> index, ArrayRef<Attribute> operands,
880     SmallVectorImpl<RegionSuccessor> &regions) {
881   // We always branch to the join region.
882   if (index.hasValue()) {
883     if (index.getValue() < 2)
884       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
885     else
886       regions.push_back(RegionSuccessor(getResults()));
887     return;
888   }
889 
890   // The then and else regions are the entry regions of this op.
891   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
892   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
893 }
894 
895 #include "TestOpEnums.cpp.inc"
896 #include "TestOpInterfaces.cpp.inc"
897 #include "TestOpStructs.cpp.inc"
898 #include "TestTypeInterfaces.cpp.inc"
899 
900 #define GET_OP_CLASSES
901 #include "TestOps.cpp.inc"
902