1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/DLTI/DLTI.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/FoldUtils.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "llvm/ADT/StringSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::test;
25 
26 void mlir::test::registerTestDialect(DialectRegistry &registry) {
27   registry.insert<TestDialect>();
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // TestDialect Interfaces
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 
36 // Test support for interacting with the AsmPrinter.
37 struct TestOpAsmInterface : public OpAsmDialectInterface {
38   using OpAsmDialectInterface::OpAsmDialectInterface;
39 
40   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
41     StringAttr strAttr = attr.dyn_cast<StringAttr>();
42     if (!strAttr)
43       return failure();
44 
45     // Check the contents of the string attribute to see what the test alias
46     // should be named.
47     Optional<StringRef> aliasName =
48         StringSwitch<Optional<StringRef>>(strAttr.getValue())
49             .Case("alias_test:dot_in_name", StringRef("test.alias"))
50             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
51             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
52             .Case("alias_test:sanitize_conflict_a",
53                   StringRef("test_alias_conflict0"))
54             .Case("alias_test:sanitize_conflict_b",
55                   StringRef("test_alias_conflict0_"))
56             .Default(llvm::None);
57     if (!aliasName)
58       return failure();
59 
60     os << *aliasName;
61     return success();
62   }
63 
64   void getAsmResultNames(Operation *op,
65                          OpAsmSetValueNameFn setNameFn) const final {
66     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
67       setNameFn(asmOp, "result");
68   }
69 
70   void getAsmBlockArgumentNames(Block *block,
71                                 OpAsmSetValueNameFn setNameFn) const final {
72     auto op = block->getParentOp();
73     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
74     if (!arrayAttr)
75       return;
76     auto args = block->getArguments();
77     auto e = std::min(arrayAttr.size(), args.size());
78     for (unsigned i = 0; i < e; ++i) {
79       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
80         setNameFn(args[i], strAttr.getValue());
81     }
82   }
83 };
84 
85 struct TestDialectFoldInterface : public DialectFoldInterface {
86   using DialectFoldInterface::DialectFoldInterface;
87 
88   /// Registered hook to check if the given region, which is attached to an
89   /// operation that is *not* isolated from above, should be used when
90   /// materializing constants.
91   bool shouldMaterializeInto(Region *region) const final {
92     // If this is a one region operation, then insert into it.
93     return isa<OneRegionOp>(region->getParentOp());
94   }
95 };
96 
97 /// This class defines the interface for handling inlining with standard
98 /// operations.
99 struct TestInlinerInterface : public DialectInlinerInterface {
100   using DialectInlinerInterface::DialectInlinerInterface;
101 
102   //===--------------------------------------------------------------------===//
103   // Analysis Hooks
104   //===--------------------------------------------------------------------===//
105 
106   bool isLegalToInline(Operation *call, Operation *callable,
107                        bool wouldBeCloned) const final {
108     // Don't allow inlining calls that are marked `noinline`.
109     return !call->hasAttr("noinline");
110   }
111   bool isLegalToInline(Region *, Region *, bool,
112                        BlockAndValueMapping &) const final {
113     // Inlining into test dialect regions is legal.
114     return true;
115   }
116   bool isLegalToInline(Operation *, Region *, bool,
117                        BlockAndValueMapping &) const final {
118     return true;
119   }
120 
121   bool shouldAnalyzeRecursively(Operation *op) const final {
122     // Analyze recursively if this is not a functional region operation, it
123     // froms a separate functional scope.
124     return !isa<FunctionalRegionOp>(op);
125   }
126 
127   //===--------------------------------------------------------------------===//
128   // Transformation Hooks
129   //===--------------------------------------------------------------------===//
130 
131   /// Handle the given inlined terminator by replacing it with a new operation
132   /// as necessary.
133   void handleTerminator(Operation *op,
134                         ArrayRef<Value> valuesToRepl) const final {
135     // Only handle "test.return" here.
136     auto returnOp = dyn_cast<TestReturnOp>(op);
137     if (!returnOp)
138       return;
139 
140     // Replace the values directly with the return operands.
141     assert(returnOp.getNumOperands() == valuesToRepl.size());
142     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
143       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
144   }
145 
146   /// Attempt to materialize a conversion for a type mismatch between a call
147   /// from this dialect, and a callable region. This method should generate an
148   /// operation that takes 'input' as the only operand, and produces a single
149   /// result of 'resultType'. If a conversion can not be generated, nullptr
150   /// should be returned.
151   Operation *materializeCallConversion(OpBuilder &builder, Value input,
152                                        Type resultType,
153                                        Location conversionLoc) const final {
154     // Only allow conversion for i16/i32 types.
155     if (!(resultType.isSignlessInteger(16) ||
156           resultType.isSignlessInteger(32)) ||
157         !(input.getType().isSignlessInteger(16) ||
158           input.getType().isSignlessInteger(32)))
159       return nullptr;
160     return builder.create<TestCastOp>(conversionLoc, resultType, input);
161   }
162 };
163 } // end anonymous namespace
164 
165 //===----------------------------------------------------------------------===//
166 // TestDialect
167 //===----------------------------------------------------------------------===//
168 
169 void TestDialect::initialize() {
170   registerAttributes();
171   registerTypes();
172   addOperations<
173 #define GET_OP_LIST
174 #include "TestOps.cpp.inc"
175       >();
176   addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
177                 TestInlinerInterface>();
178   allowUnknownOperations();
179 }
180 
181 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
182                                             Type type, Location loc) {
183   return builder.create<TestOpConstant>(loc, type, value);
184 }
185 
186 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
187                                                     NamedAttribute namedAttr) {
188   if (namedAttr.first == "test.invalid_attr")
189     return op->emitError() << "invalid to use 'test.invalid_attr'";
190   return success();
191 }
192 
193 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
194                                                     unsigned regionIndex,
195                                                     unsigned argIndex,
196                                                     NamedAttribute namedAttr) {
197   if (namedAttr.first == "test.invalid_attr")
198     return op->emitError() << "invalid to use 'test.invalid_attr'";
199   return success();
200 }
201 
202 LogicalResult
203 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
204                                          unsigned resultIndex,
205                                          NamedAttribute namedAttr) {
206   if (namedAttr.first == "test.invalid_attr")
207     return op->emitError() << "invalid to use 'test.invalid_attr'";
208   return success();
209 }
210 
211 //===----------------------------------------------------------------------===//
212 // TestBranchOp
213 //===----------------------------------------------------------------------===//
214 
215 Optional<MutableOperandRange>
216 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
217   assert(index == 0 && "invalid successor index");
218   return targetOperandsMutable();
219 }
220 
221 //===----------------------------------------------------------------------===//
222 // TestFoldToCallOp
223 //===----------------------------------------------------------------------===//
224 
225 namespace {
226 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
227   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
228 
229   LogicalResult matchAndRewrite(FoldToCallOp op,
230                                 PatternRewriter &rewriter) const override {
231     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
232                                         ValueRange());
233     return success();
234   }
235 };
236 } // end anonymous namespace
237 
238 void FoldToCallOp::getCanonicalizationPatterns(
239     OwningRewritePatternList &results, MLIRContext *context) {
240   results.insert<FoldToCallOpPattern>(context);
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // Test Format* operations
245 //===----------------------------------------------------------------------===//
246 
247 //===----------------------------------------------------------------------===//
248 // Parsing
249 
250 static ParseResult parseCustomDirectiveOperands(
251     OpAsmParser &parser, OpAsmParser::OperandType &operand,
252     Optional<OpAsmParser::OperandType> &optOperand,
253     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
254   if (parser.parseOperand(operand))
255     return failure();
256   if (succeeded(parser.parseOptionalComma())) {
257     optOperand.emplace();
258     if (parser.parseOperand(*optOperand))
259       return failure();
260   }
261   if (parser.parseArrow() || parser.parseLParen() ||
262       parser.parseOperandList(varOperands) || parser.parseRParen())
263     return failure();
264   return success();
265 }
266 static ParseResult
267 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
268                             Type &optOperandType,
269                             SmallVectorImpl<Type> &varOperandTypes) {
270   if (parser.parseColon())
271     return failure();
272 
273   if (parser.parseType(operandType))
274     return failure();
275   if (succeeded(parser.parseOptionalComma())) {
276     if (parser.parseType(optOperandType))
277       return failure();
278   }
279   if (parser.parseArrow() || parser.parseLParen() ||
280       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
281     return failure();
282   return success();
283 }
284 static ParseResult
285 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
286                                  Type optOperandType,
287                                  const SmallVectorImpl<Type> &varOperandTypes) {
288   if (parser.parseKeyword("type_refs_capture"))
289     return failure();
290 
291   Type operandType2, optOperandType2;
292   SmallVector<Type, 1> varOperandTypes2;
293   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
294                                   varOperandTypes2))
295     return failure();
296 
297   if (operandType != operandType2 || optOperandType != optOperandType2 ||
298       varOperandTypes != varOperandTypes2)
299     return failure();
300 
301   return success();
302 }
303 static ParseResult parseCustomDirectiveOperandsAndTypes(
304     OpAsmParser &parser, OpAsmParser::OperandType &operand,
305     Optional<OpAsmParser::OperandType> &optOperand,
306     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
307     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
308   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
309       parseCustomDirectiveResults(parser, operandType, optOperandType,
310                                   varOperandTypes))
311     return failure();
312   return success();
313 }
314 static ParseResult parseCustomDirectiveRegions(
315     OpAsmParser &parser, Region &region,
316     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
317   if (parser.parseRegion(region))
318     return failure();
319   if (failed(parser.parseOptionalComma()))
320     return success();
321   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
322   if (parser.parseRegion(*varRegion))
323     return failure();
324   varRegions.emplace_back(std::move(varRegion));
325   return success();
326 }
327 static ParseResult
328 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
329                                SmallVectorImpl<Block *> &varSuccessors) {
330   if (parser.parseSuccessor(successor))
331     return failure();
332   if (failed(parser.parseOptionalComma()))
333     return success();
334   Block *varSuccessor;
335   if (parser.parseSuccessor(varSuccessor))
336     return failure();
337   varSuccessors.append(2, varSuccessor);
338   return success();
339 }
340 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
341                                                   IntegerAttr &attr,
342                                                   IntegerAttr &optAttr) {
343   if (parser.parseAttribute(attr))
344     return failure();
345   if (succeeded(parser.parseOptionalComma())) {
346     if (parser.parseAttribute(optAttr))
347       return failure();
348   }
349   return success();
350 }
351 
352 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
353                                                 NamedAttrList &attrs) {
354   return parser.parseOptionalAttrDict(attrs);
355 }
356 static ParseResult parseCustomDirectiveOptionalOperandRef(
357     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
358   int64_t operandCount = 0;
359   if (parser.parseInteger(operandCount))
360     return failure();
361   bool expectedOptionalOperand = operandCount == 0;
362   return success(expectedOptionalOperand != optOperand.hasValue());
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // Printing
367 
368 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
369                                          Value operand, Value optOperand,
370                                          OperandRange varOperands) {
371   printer << operand;
372   if (optOperand)
373     printer << ", " << optOperand;
374   printer << " -> (" << varOperands << ")";
375 }
376 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
377                                         Type operandType, Type optOperandType,
378                                         TypeRange varOperandTypes) {
379   printer << " : " << operandType;
380   if (optOperandType)
381     printer << ", " << optOperandType;
382   printer << " -> (" << varOperandTypes << ")";
383 }
384 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
385                                              Operation *op, Type operandType,
386                                              Type optOperandType,
387                                              TypeRange varOperandTypes) {
388   printer << " type_refs_capture ";
389   printCustomDirectiveResults(printer, op, operandType, optOperandType,
390                               varOperandTypes);
391 }
392 static void printCustomDirectiveOperandsAndTypes(
393     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
394     OperandRange varOperands, Type operandType, Type optOperandType,
395     TypeRange varOperandTypes) {
396   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
397   printCustomDirectiveResults(printer, op, operandType, optOperandType,
398                               varOperandTypes);
399 }
400 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
401                                         Region &region,
402                                         MutableArrayRef<Region> varRegions) {
403   printer.printRegion(region);
404   if (!varRegions.empty()) {
405     printer << ", ";
406     for (Region &region : varRegions)
407       printer.printRegion(region);
408   }
409 }
410 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
411                                            Block *successor,
412                                            SuccessorRange varSuccessors) {
413   printer << successor;
414   if (!varSuccessors.empty())
415     printer << ", " << varSuccessors.front();
416 }
417 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
418                                            Attribute attribute,
419                                            Attribute optAttribute) {
420   printer << attribute;
421   if (optAttribute)
422     printer << ", " << optAttribute;
423 }
424 
425 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
426                                          DictionaryAttr attrs) {
427   printer.printOptionalAttrDict(attrs.getValue());
428 }
429 
430 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
431                                                    Operation *op,
432                                                    Value optOperand) {
433   printer << (optOperand ? "1" : "0");
434 }
435 
436 //===----------------------------------------------------------------------===//
437 // Test IsolatedRegionOp - parse passthrough region arguments.
438 //===----------------------------------------------------------------------===//
439 
440 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
441                                          OperationState &result) {
442   OpAsmParser::OperandType argInfo;
443   Type argType = parser.getBuilder().getIndexType();
444 
445   // Parse the input operand.
446   if (parser.parseOperand(argInfo) ||
447       parser.resolveOperand(argInfo, argType, result.operands))
448     return failure();
449 
450   // Parse the body region, and reuse the operand info as the argument info.
451   Region *body = result.addRegion();
452   return parser.parseRegion(*body, argInfo, argType,
453                             /*enableNameShadowing=*/true);
454 }
455 
456 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
457   p << "test.isolated_region ";
458   p.printOperand(op.getOperand());
459   p.shadowRegionArgs(op.region(), op.getOperand());
460   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
461 }
462 
463 //===----------------------------------------------------------------------===//
464 // Test SSACFGRegionOp
465 //===----------------------------------------------------------------------===//
466 
467 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
468   return RegionKind::SSACFG;
469 }
470 
471 //===----------------------------------------------------------------------===//
472 // Test GraphRegionOp
473 //===----------------------------------------------------------------------===//
474 
475 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
476                                       OperationState &result) {
477   // Parse the body region, and reuse the operand info as the argument info.
478   Region *body = result.addRegion();
479   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
480 }
481 
482 static void print(OpAsmPrinter &p, GraphRegionOp op) {
483   p << "test.graph_region ";
484   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
485 }
486 
487 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
488   return RegionKind::Graph;
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // Test AffineScopeOp
493 //===----------------------------------------------------------------------===//
494 
495 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
496                                       OperationState &result) {
497   // Parse the body region, and reuse the operand info as the argument info.
498   Region *body = result.addRegion();
499   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
500 }
501 
502 static void print(OpAsmPrinter &p, AffineScopeOp op) {
503   p << "test.affine_scope ";
504   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
505 }
506 
507 //===----------------------------------------------------------------------===//
508 // Test parser.
509 //===----------------------------------------------------------------------===//
510 
511 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
512                                               OperationState &result) {
513   if (parser.parseOptionalColon())
514     return success();
515   uint64_t numResults;
516   if (parser.parseInteger(numResults))
517     return failure();
518 
519   IndexType type = parser.getBuilder().getIndexType();
520   for (unsigned i = 0; i < numResults; ++i)
521     result.addTypes(type);
522   return success();
523 }
524 
525 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
526   p << ParseIntegerLiteralOp::getOperationName();
527   if (unsigned numResults = op->getNumResults())
528     p << " : " << numResults;
529 }
530 
531 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
532                                               OperationState &result) {
533   StringRef keyword;
534   if (parser.parseKeyword(&keyword))
535     return failure();
536   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
537   return success();
538 }
539 
540 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
541   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
542 }
543 
544 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
546 
547 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
548                                          OperationState &result) {
549   if (parser.parseKeyword("wraps"))
550     return failure();
551 
552   // Parse the wrapped op in a region
553   Region &body = *result.addRegion();
554   body.push_back(new Block);
555   Block &block = body.back();
556   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
557   if (!wrapped_op)
558     return failure();
559 
560   // Create a return terminator in the inner region, pass as operand to the
561   // terminator the returned values from the wrapped operation.
562   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
563   OpBuilder builder(parser.getBuilder().getContext());
564   builder.setInsertionPointToEnd(&block);
565   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
566 
567   // Get the results type for the wrapping op from the terminator operands.
568   Operation &return_op = body.back().back();
569   result.types.append(return_op.operand_type_begin(),
570                       return_op.operand_type_end());
571 
572   // Use the location of the wrapped op for the "test.wrapping_region" op.
573   result.location = wrapped_op->getLoc();
574 
575   return success();
576 }
577 
578 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
579   p << op.getOperationName() << " wraps ";
580   p.printGenericOp(&op.region().front().front());
581 }
582 
583 //===----------------------------------------------------------------------===//
584 // Test PolyForOp - parse list of region arguments.
585 //===----------------------------------------------------------------------===//
586 
587 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
588   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
589   // Parse list of region arguments without a delimiter.
590   if (parser.parseRegionArgumentList(ivsInfo))
591     return failure();
592 
593   // Parse the body region.
594   Region *body = result.addRegion();
595   auto &builder = parser.getBuilder();
596   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
597   return parser.parseRegion(*body, ivsInfo, argTypes);
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // Test removing op with inner ops.
602 //===----------------------------------------------------------------------===//
603 
604 namespace {
605 struct TestRemoveOpWithInnerOps
606     : public OpRewritePattern<TestOpWithRegionPattern> {
607   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
608 
609   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
610                                 PatternRewriter &rewriter) const override {
611     rewriter.eraseOp(op);
612     return success();
613   }
614 };
615 } // end anonymous namespace
616 
617 void TestOpWithRegionPattern::getCanonicalizationPatterns(
618     OwningRewritePatternList &results, MLIRContext *context) {
619   results.insert<TestRemoveOpWithInnerOps>(context);
620 }
621 
622 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
623   return operand();
624 }
625 
626 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
627   return getValue();
628 }
629 
630 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
631     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
632   for (Value input : this->operands()) {
633     results.push_back(input);
634   }
635   return success();
636 }
637 
638 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
639   assert(operands.size() == 1);
640   if (operands.front()) {
641     (*this)->setAttr("attr", operands.front());
642     return getResult();
643   }
644   return {};
645 }
646 
647 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
648   return getOperand();
649 }
650 
651 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
652     MLIRContext *, Optional<Location> location, ValueRange operands,
653     DictionaryAttr attributes, RegionRange regions,
654     SmallVectorImpl<Type> &inferredReturnTypes) {
655   if (operands[0].getType() != operands[1].getType()) {
656     return emitOptionalError(location, "operand type mismatch ",
657                              operands[0].getType(), " vs ",
658                              operands[1].getType());
659   }
660   inferredReturnTypes.assign({operands[0].getType()});
661   return success();
662 }
663 
664 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
665     MLIRContext *context, Optional<Location> location, ValueRange operands,
666     DictionaryAttr attributes, RegionRange regions,
667     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
668   // Create return type consisting of the last element of the first operand.
669   auto operandType = *operands.getTypes().begin();
670   auto sval = operandType.dyn_cast<ShapedType>();
671   if (!sval) {
672     return emitOptionalError(location, "only shaped type operands allowed");
673   }
674   int64_t dim =
675       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
676   auto type = IntegerType::get(context, 17);
677   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
678   return success();
679 }
680 
681 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
682     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
683   shapes = SmallVector<Value, 1>{
684       builder.createOrFold<memref::DimOp>(getLoc(), getOperand(0), 0)};
685   return success();
686 }
687 
688 //===----------------------------------------------------------------------===//
689 // Test SideEffect interfaces
690 //===----------------------------------------------------------------------===//
691 
692 namespace {
693 /// A test resource for side effects.
694 struct TestResource : public SideEffects::Resource::Base<TestResource> {
695   StringRef getName() final { return "<Test>"; }
696 };
697 } // end anonymous namespace
698 
699 void SideEffectOp::getEffects(
700     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
701   // Check for an effects attribute on the op instance.
702   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
703   if (!effectsAttr)
704     return;
705 
706   // If there is one, it is an array of dictionary attributes that hold
707   // information on the effects of this operation.
708   for (Attribute element : effectsAttr) {
709     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
710 
711     // Get the specific memory effect.
712     MemoryEffects::Effect *effect =
713         StringSwitch<MemoryEffects::Effect *>(
714             effectElement.get("effect").cast<StringAttr>().getValue())
715             .Case("allocate", MemoryEffects::Allocate::get())
716             .Case("free", MemoryEffects::Free::get())
717             .Case("read", MemoryEffects::Read::get())
718             .Case("write", MemoryEffects::Write::get());
719 
720     // Check for a non-default resource to use.
721     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
722     if (effectElement.get("test_resource"))
723       resource = TestResource::get();
724 
725     // Check for a result to affect.
726     if (effectElement.get("on_result"))
727       effects.emplace_back(effect, getResult(), resource);
728     else if (Attribute ref = effectElement.get("on_reference"))
729       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
730     else
731       effects.emplace_back(effect, resource);
732   }
733 }
734 
735 void SideEffectOp::getEffects(
736     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
737   auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
738   if (!effectsAttr)
739     return;
740 
741   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
742 }
743 
744 //===----------------------------------------------------------------------===//
745 // StringAttrPrettyNameOp
746 //===----------------------------------------------------------------------===//
747 
748 // This op has fancy handling of its SSA result name.
749 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
750                                                OperationState &result) {
751   // Add the result types.
752   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
753     result.addTypes(parser.getBuilder().getIntegerType(32));
754 
755   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
756     return failure();
757 
758   // If the attribute dictionary contains no 'names' attribute, infer it from
759   // the SSA name (if specified).
760   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
761     return attr.first == "names";
762   });
763 
764   // If there was no name specified, check to see if there was a useful name
765   // specified in the asm file.
766   if (hadNames || parser.getNumResults() == 0)
767     return success();
768 
769   SmallVector<StringRef, 4> names;
770   auto *context = result.getContext();
771 
772   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
773     auto resultName = parser.getResultName(i);
774     StringRef nameStr;
775     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
776       nameStr = resultName.first;
777 
778     names.push_back(nameStr);
779   }
780 
781   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
782   result.attributes.push_back({Identifier::get("names", context), namesAttr});
783   return success();
784 }
785 
786 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
787   p << "test.string_attr_pretty_name";
788 
789   // Note that we only need to print the "name" attribute if the asmprinter
790   // result name disagrees with it.  This can happen in strange cases, e.g.
791   // when there are conflicts.
792   bool namesDisagree = op.names().size() != op.getNumResults();
793 
794   SmallString<32> resultNameStr;
795   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
796     resultNameStr.clear();
797     llvm::raw_svector_ostream tmpStream(resultNameStr);
798     p.printOperand(op.getResult(i), tmpStream);
799 
800     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
801     if (!expectedName ||
802         tmpStream.str().drop_front() != expectedName.getValue()) {
803       namesDisagree = true;
804     }
805   }
806 
807   if (namesDisagree)
808     p.printOptionalAttrDictWithKeyword(op->getAttrs());
809   else
810     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
811 }
812 
813 // We set the SSA name in the asm syntax to the contents of the name
814 // attribute.
815 void StringAttrPrettyNameOp::getAsmResultNames(
816     function_ref<void(Value, StringRef)> setNameFn) {
817 
818   auto value = names();
819   for (size_t i = 0, e = value.size(); i != e; ++i)
820     if (auto str = value[i].dyn_cast<StringAttr>())
821       if (!str.getValue().empty())
822         setNameFn(getResult(i), str.getValue());
823 }
824 
825 //===----------------------------------------------------------------------===//
826 // RegionIfOp
827 //===----------------------------------------------------------------------===//
828 
829 static void print(OpAsmPrinter &p, RegionIfOp op) {
830   p << RegionIfOp::getOperationName() << " ";
831   p.printOperands(op.getOperands());
832   p << ": " << op.getOperandTypes();
833   p.printArrowTypeList(op.getResultTypes());
834   p << " then";
835   p.printRegion(op.thenRegion(),
836                 /*printEntryBlockArgs=*/true,
837                 /*printBlockTerminators=*/true);
838   p << " else";
839   p.printRegion(op.elseRegion(),
840                 /*printEntryBlockArgs=*/true,
841                 /*printBlockTerminators=*/true);
842   p << " join";
843   p.printRegion(op.joinRegion(),
844                 /*printEntryBlockArgs=*/true,
845                 /*printBlockTerminators=*/true);
846 }
847 
848 static ParseResult parseRegionIfOp(OpAsmParser &parser,
849                                    OperationState &result) {
850   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
851   SmallVector<Type, 2> operandTypes;
852 
853   result.regions.reserve(3);
854   Region *thenRegion = result.addRegion();
855   Region *elseRegion = result.addRegion();
856   Region *joinRegion = result.addRegion();
857 
858   // Parse operand, type and arrow type lists.
859   if (parser.parseOperandList(operandInfos) ||
860       parser.parseColonTypeList(operandTypes) ||
861       parser.parseArrowTypeList(result.types))
862     return failure();
863 
864   // Parse all attached regions.
865   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
866       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
867       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
868     return failure();
869 
870   return parser.resolveOperands(operandInfos, operandTypes,
871                                 parser.getCurrentLocation(), result.operands);
872 }
873 
874 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
875   assert(index < 2 && "invalid region index");
876   return getOperands();
877 }
878 
879 void RegionIfOp::getSuccessorRegions(
880     Optional<unsigned> index, ArrayRef<Attribute> operands,
881     SmallVectorImpl<RegionSuccessor> &regions) {
882   // We always branch to the join region.
883   if (index.hasValue()) {
884     if (index.getValue() < 2)
885       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
886     else
887       regions.push_back(RegionSuccessor(getResults()));
888     return;
889   }
890 
891   // The then and else regions are the entry regions of this op.
892   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
893   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
894 }
895 
896 #include "TestOpEnums.cpp.inc"
897 #include "TestOpInterfaces.cpp.inc"
898 #include "TestOpStructs.cpp.inc"
899 #include "TestTypeInterfaces.cpp.inc"
900 
901 #define GET_OP_CLASSES
902 #include "TestOps.cpp.inc"
903