1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestAttributes.h"
11 #include "TestTypes.h"
12 #include "mlir/Dialect/DLTI/DLTI.h"
13 #include "mlir/Dialect/MemRef/IR/MemRef.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "mlir/Transforms/FoldUtils.h"
20 #include "mlir/Transforms/InliningUtils.h"
21 #include "llvm/ADT/StringSwitch.h"
22 
23 using namespace mlir;
24 using namespace mlir::test;
25 
26 void mlir::test::registerTestDialect(DialectRegistry &registry) {
27   registry.insert<TestDialect>();
28 }
29 
30 //===----------------------------------------------------------------------===//
31 // TestDialect Interfaces
32 //===----------------------------------------------------------------------===//
33 
34 namespace {
35 
36 // Test support for interacting with the AsmPrinter.
37 struct TestOpAsmInterface : public OpAsmDialectInterface {
38   using OpAsmDialectInterface::OpAsmDialectInterface;
39 
40   LogicalResult getAlias(Attribute attr, raw_ostream &os) const final {
41     StringAttr strAttr = attr.dyn_cast<StringAttr>();
42     if (!strAttr)
43       return failure();
44 
45     // Check the contents of the string attribute to see what the test alias
46     // should be named.
47     Optional<StringRef> aliasName =
48         StringSwitch<Optional<StringRef>>(strAttr.getValue())
49             .Case("alias_test:dot_in_name", StringRef("test.alias"))
50             .Case("alias_test:trailing_digit", StringRef("test_alias0"))
51             .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
52             .Case("alias_test:sanitize_conflict_a",
53                   StringRef("test_alias_conflict0"))
54             .Case("alias_test:sanitize_conflict_b",
55                   StringRef("test_alias_conflict0_"))
56             .Default(llvm::None);
57     if (!aliasName)
58       return failure();
59 
60     os << *aliasName;
61     return success();
62   }
63 
64   void getAsmResultNames(Operation *op,
65                          OpAsmSetValueNameFn setNameFn) const final {
66     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
67       setNameFn(asmOp, "result");
68   }
69 
70   void getAsmBlockArgumentNames(Block *block,
71                                 OpAsmSetValueNameFn setNameFn) const final {
72     auto op = block->getParentOp();
73     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
74     if (!arrayAttr)
75       return;
76     auto args = block->getArguments();
77     auto e = std::min(arrayAttr.size(), args.size());
78     for (unsigned i = 0; i < e; ++i) {
79       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
80         setNameFn(args[i], strAttr.getValue());
81     }
82   }
83 };
84 
85 struct TestDialectFoldInterface : public DialectFoldInterface {
86   using DialectFoldInterface::DialectFoldInterface;
87 
88   /// Registered hook to check if the given region, which is attached to an
89   /// operation that is *not* isolated from above, should be used when
90   /// materializing constants.
91   bool shouldMaterializeInto(Region *region) const final {
92     // If this is a one region operation, then insert into it.
93     return isa<OneRegionOp>(region->getParentOp());
94   }
95 };
96 
97 /// This class defines the interface for handling inlining with standard
98 /// operations.
99 struct TestInlinerInterface : public DialectInlinerInterface {
100   using DialectInlinerInterface::DialectInlinerInterface;
101 
102   //===--------------------------------------------------------------------===//
103   // Analysis Hooks
104   //===--------------------------------------------------------------------===//
105 
106   bool isLegalToInline(Operation *call, Operation *callable,
107                        bool wouldBeCloned) const final {
108     // Don't allow inlining calls that are marked `noinline`.
109     return !call->hasAttr("noinline");
110   }
111   bool isLegalToInline(Region *, Region *, bool,
112                        BlockAndValueMapping &) const final {
113     // Inlining into test dialect regions is legal.
114     return true;
115   }
116   bool isLegalToInline(Operation *, Region *, bool,
117                        BlockAndValueMapping &) const final {
118     return true;
119   }
120 
121   bool shouldAnalyzeRecursively(Operation *op) const final {
122     // Analyze recursively if this is not a functional region operation, it
123     // froms a separate functional scope.
124     return !isa<FunctionalRegionOp>(op);
125   }
126 
127   //===--------------------------------------------------------------------===//
128   // Transformation Hooks
129   //===--------------------------------------------------------------------===//
130 
131   /// Handle the given inlined terminator by replacing it with a new operation
132   /// as necessary.
133   void handleTerminator(Operation *op,
134                         ArrayRef<Value> valuesToRepl) const final {
135     // Only handle "test.return" here.
136     auto returnOp = dyn_cast<TestReturnOp>(op);
137     if (!returnOp)
138       return;
139 
140     // Replace the values directly with the return operands.
141     assert(returnOp.getNumOperands() == valuesToRepl.size());
142     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
143       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
144   }
145 
146   /// Attempt to materialize a conversion for a type mismatch between a call
147   /// from this dialect, and a callable region. This method should generate an
148   /// operation that takes 'input' as the only operand, and produces a single
149   /// result of 'resultType'. If a conversion can not be generated, nullptr
150   /// should be returned.
151   Operation *materializeCallConversion(OpBuilder &builder, Value input,
152                                        Type resultType,
153                                        Location conversionLoc) const final {
154     // Only allow conversion for i16/i32 types.
155     if (!(resultType.isSignlessInteger(16) ||
156           resultType.isSignlessInteger(32)) ||
157         !(input.getType().isSignlessInteger(16) ||
158           input.getType().isSignlessInteger(32)))
159       return nullptr;
160     return builder.create<TestCastOp>(conversionLoc, resultType, input);
161   }
162 };
163 } // end anonymous namespace
164 
165 //===----------------------------------------------------------------------===//
166 // TestDialect
167 //===----------------------------------------------------------------------===//
168 
/// Dialect initialization hook. It registers:
/// - the test attributes and types;
/// - every operation listed in the GET_OP_LIST section of the generated
///   TestOps.cpp.inc;
/// - the three dialect interfaces implemented in this file.
/// It also opts the dialect into allowing unknown operations.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The template argument list is expanded from the generated op list.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  // NOTE(review): presumably lets IR contain unregistered ops in this
  // dialect's namespace — see Dialect::allowUnknownOperations.
  allowUnknownOperations();
}
180 
181 Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
182                                             Type type, Location loc) {
183   return builder.create<TestOpConstant>(loc, type, value);
184 }
185 
186 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
187                                                     NamedAttribute namedAttr) {
188   if (namedAttr.first == "test.invalid_attr")
189     return op->emitError() << "invalid to use 'test.invalid_attr'";
190   return success();
191 }
192 
193 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
194                                                     unsigned regionIndex,
195                                                     unsigned argIndex,
196                                                     NamedAttribute namedAttr) {
197   if (namedAttr.first == "test.invalid_attr")
198     return op->emitError() << "invalid to use 'test.invalid_attr'";
199   return success();
200 }
201 
202 LogicalResult
203 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
204                                          unsigned resultIndex,
205                                          NamedAttribute namedAttr) {
206   if (namedAttr.first == "test.invalid_attr")
207     return op->emitError() << "invalid to use 'test.invalid_attr'";
208   return success();
209 }
210 
211 Optional<Dialect::ParseOpHook>
212 TestDialect::getParseOperationHook(StringRef opName) const {
213   if (opName == "test.dialect_custom_printer") {
214     return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
215       return parser.parseKeyword("custom_format");
216     }};
217   }
218   return None;
219 }
220 
221 LogicalResult TestDialect::printOperation(Operation *op,
222                                           OpAsmPrinter &printer) const {
223   StringRef opName = op->getName().getStringRef();
224   if (opName == "test.dialect_custom_printer") {
225     printer.getStream() << opName << " custom_format";
226     return success();
227   }
228   return failure();
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // TestBranchOp
233 //===----------------------------------------------------------------------===//
234 
235 Optional<MutableOperandRange>
236 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
237   assert(index == 0 && "invalid successor index");
238   return targetOperandsMutable();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // TestFoldToCallOp
243 //===----------------------------------------------------------------------===//
244 
245 namespace {
246 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
247   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
248 
249   LogicalResult matchAndRewrite(FoldToCallOp op,
250                                 PatternRewriter &rewriter) const override {
251     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
252                                         ValueRange());
253     return success();
254   }
255 };
256 } // end anonymous namespace
257 
258 void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
259                                                MLIRContext *context) {
260   results.add<FoldToCallOpPattern>(context);
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // Test Format* operations
265 //===----------------------------------------------------------------------===//
266 
267 //===----------------------------------------------------------------------===//
268 // Parsing
269 
270 static ParseResult parseCustomDirectiveOperands(
271     OpAsmParser &parser, OpAsmParser::OperandType &operand,
272     Optional<OpAsmParser::OperandType> &optOperand,
273     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
274   if (parser.parseOperand(operand))
275     return failure();
276   if (succeeded(parser.parseOptionalComma())) {
277     optOperand.emplace();
278     if (parser.parseOperand(*optOperand))
279       return failure();
280   }
281   if (parser.parseArrow() || parser.parseLParen() ||
282       parser.parseOperandList(varOperands) || parser.parseRParen())
283     return failure();
284   return success();
285 }
286 static ParseResult
287 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
288                             Type &optOperandType,
289                             SmallVectorImpl<Type> &varOperandTypes) {
290   if (parser.parseColon())
291     return failure();
292 
293   if (parser.parseType(operandType))
294     return failure();
295   if (succeeded(parser.parseOptionalComma())) {
296     if (parser.parseType(optOperandType))
297       return failure();
298   }
299   if (parser.parseArrow() || parser.parseLParen() ||
300       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
301     return failure();
302   return success();
303 }
304 static ParseResult
305 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
306                                  Type optOperandType,
307                                  const SmallVectorImpl<Type> &varOperandTypes) {
308   if (parser.parseKeyword("type_refs_capture"))
309     return failure();
310 
311   Type operandType2, optOperandType2;
312   SmallVector<Type, 1> varOperandTypes2;
313   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
314                                   varOperandTypes2))
315     return failure();
316 
317   if (operandType != operandType2 || optOperandType != optOperandType2 ||
318       varOperandTypes != varOperandTypes2)
319     return failure();
320 
321   return success();
322 }
323 static ParseResult parseCustomDirectiveOperandsAndTypes(
324     OpAsmParser &parser, OpAsmParser::OperandType &operand,
325     Optional<OpAsmParser::OperandType> &optOperand,
326     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
327     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
328   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
329       parseCustomDirectiveResults(parser, operandType, optOperandType,
330                                   varOperandTypes))
331     return failure();
332   return success();
333 }
334 static ParseResult parseCustomDirectiveRegions(
335     OpAsmParser &parser, Region &region,
336     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
337   if (parser.parseRegion(region))
338     return failure();
339   if (failed(parser.parseOptionalComma()))
340     return success();
341   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
342   if (parser.parseRegion(*varRegion))
343     return failure();
344   varRegions.emplace_back(std::move(varRegion));
345   return success();
346 }
347 static ParseResult
348 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
349                                SmallVectorImpl<Block *> &varSuccessors) {
350   if (parser.parseSuccessor(successor))
351     return failure();
352   if (failed(parser.parseOptionalComma()))
353     return success();
354   Block *varSuccessor;
355   if (parser.parseSuccessor(varSuccessor))
356     return failure();
357   varSuccessors.append(2, varSuccessor);
358   return success();
359 }
360 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
361                                                   IntegerAttr &attr,
362                                                   IntegerAttr &optAttr) {
363   if (parser.parseAttribute(attr))
364     return failure();
365   if (succeeded(parser.parseOptionalComma())) {
366     if (parser.parseAttribute(optAttr))
367       return failure();
368   }
369   return success();
370 }
371 
372 static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
373                                                 NamedAttrList &attrs) {
374   return parser.parseOptionalAttrDict(attrs);
375 }
376 static ParseResult parseCustomDirectiveOptionalOperandRef(
377     OpAsmParser &parser, Optional<OpAsmParser::OperandType> &optOperand) {
378   int64_t operandCount = 0;
379   if (parser.parseInteger(operandCount))
380     return failure();
381   bool expectedOptionalOperand = operandCount == 0;
382   return success(expectedOptionalOperand != optOperand.hasValue());
383 }
384 
385 //===----------------------------------------------------------------------===//
386 // Printing
387 
388 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
389                                          Value operand, Value optOperand,
390                                          OperandRange varOperands) {
391   printer << operand;
392   if (optOperand)
393     printer << ", " << optOperand;
394   printer << " -> (" << varOperands << ")";
395 }
396 static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
397                                         Type operandType, Type optOperandType,
398                                         TypeRange varOperandTypes) {
399   printer << " : " << operandType;
400   if (optOperandType)
401     printer << ", " << optOperandType;
402   printer << " -> (" << varOperandTypes << ")";
403 }
404 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
405                                              Operation *op, Type operandType,
406                                              Type optOperandType,
407                                              TypeRange varOperandTypes) {
408   printer << " type_refs_capture ";
409   printCustomDirectiveResults(printer, op, operandType, optOperandType,
410                               varOperandTypes);
411 }
412 static void printCustomDirectiveOperandsAndTypes(
413     OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
414     OperandRange varOperands, Type operandType, Type optOperandType,
415     TypeRange varOperandTypes) {
416   printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
417   printCustomDirectiveResults(printer, op, operandType, optOperandType,
418                               varOperandTypes);
419 }
420 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
421                                         Region &region,
422                                         MutableArrayRef<Region> varRegions) {
423   printer.printRegion(region);
424   if (!varRegions.empty()) {
425     printer << ", ";
426     for (Region &region : varRegions)
427       printer.printRegion(region);
428   }
429 }
430 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
431                                            Block *successor,
432                                            SuccessorRange varSuccessors) {
433   printer << successor;
434   if (!varSuccessors.empty())
435     printer << ", " << varSuccessors.front();
436 }
437 static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
438                                            Attribute attribute,
439                                            Attribute optAttribute) {
440   printer << attribute;
441   if (optAttribute)
442     printer << ", " << optAttribute;
443 }
444 
445 static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
446                                          DictionaryAttr attrs) {
447   printer.printOptionalAttrDict(attrs.getValue());
448 }
449 
450 static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
451                                                    Operation *op,
452                                                    Value optOperand) {
453   printer << (optOperand ? "1" : "0");
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // Test IsolatedRegionOp - parse passthrough region arguments.
458 //===----------------------------------------------------------------------===//
459 
460 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
461                                          OperationState &result) {
462   OpAsmParser::OperandType argInfo;
463   Type argType = parser.getBuilder().getIndexType();
464 
465   // Parse the input operand.
466   if (parser.parseOperand(argInfo) ||
467       parser.resolveOperand(argInfo, argType, result.operands))
468     return failure();
469 
470   // Parse the body region, and reuse the operand info as the argument info.
471   Region *body = result.addRegion();
472   return parser.parseRegion(*body, argInfo, argType,
473                             /*enableNameShadowing=*/true);
474 }
475 
476 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
477   p << "test.isolated_region ";
478   p.printOperand(op.getOperand());
479   p.shadowRegionArgs(op.region(), op.getOperand());
480   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
481 }
482 
483 //===----------------------------------------------------------------------===//
484 // Test SSACFGRegionOp
485 //===----------------------------------------------------------------------===//
486 
487 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
488   return RegionKind::SSACFG;
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // Test GraphRegionOp
493 //===----------------------------------------------------------------------===//
494 
495 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
496                                       OperationState &result) {
497   // Parse the body region, and reuse the operand info as the argument info.
498   Region *body = result.addRegion();
499   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
500 }
501 
502 static void print(OpAsmPrinter &p, GraphRegionOp op) {
503   p << "test.graph_region ";
504   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
505 }
506 
507 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
508   return RegionKind::Graph;
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // Test AffineScopeOp
513 //===----------------------------------------------------------------------===//
514 
515 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
516                                       OperationState &result) {
517   // Parse the body region, and reuse the operand info as the argument info.
518   Region *body = result.addRegion();
519   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
520 }
521 
522 static void print(OpAsmPrinter &p, AffineScopeOp op) {
523   p << "test.affine_scope ";
524   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // Test parser.
529 //===----------------------------------------------------------------------===//
530 
531 static ParseResult parseParseIntegerLiteralOp(OpAsmParser &parser,
532                                               OperationState &result) {
533   if (parser.parseOptionalColon())
534     return success();
535   uint64_t numResults;
536   if (parser.parseInteger(numResults))
537     return failure();
538 
539   IndexType type = parser.getBuilder().getIndexType();
540   for (unsigned i = 0; i < numResults; ++i)
541     result.addTypes(type);
542   return success();
543 }
544 
545 static void print(OpAsmPrinter &p, ParseIntegerLiteralOp op) {
546   p << ParseIntegerLiteralOp::getOperationName();
547   if (unsigned numResults = op->getNumResults())
548     p << " : " << numResults;
549 }
550 
551 static ParseResult parseParseWrappedKeywordOp(OpAsmParser &parser,
552                                               OperationState &result) {
553   StringRef keyword;
554   if (parser.parseKeyword(&keyword))
555     return failure();
556   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
557   return success();
558 }
559 
560 static void print(OpAsmPrinter &p, ParseWrappedKeywordOp op) {
561   p << ParseWrappedKeywordOp::getOperationName() << " " << op.keyword();
562 }
563 
564 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
566 
567 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
568                                          OperationState &result) {
569   if (parser.parseKeyword("wraps"))
570     return failure();
571 
572   // Parse the wrapped op in a region
573   Region &body = *result.addRegion();
574   body.push_back(new Block);
575   Block &block = body.back();
576   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
577   if (!wrapped_op)
578     return failure();
579 
580   // Create a return terminator in the inner region, pass as operand to the
581   // terminator the returned values from the wrapped operation.
582   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
583   OpBuilder builder(parser.getBuilder().getContext());
584   builder.setInsertionPointToEnd(&block);
585   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
586 
587   // Get the results type for the wrapping op from the terminator operands.
588   Operation &return_op = body.back().back();
589   result.types.append(return_op.operand_type_begin(),
590                       return_op.operand_type_end());
591 
592   // Use the location of the wrapped op for the "test.wrapping_region" op.
593   result.location = wrapped_op->getLoc();
594 
595   return success();
596 }
597 
598 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
599   p << op.getOperationName() << " wraps ";
600   p.printGenericOp(&op.region().front().front());
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // Test PolyForOp - parse list of region arguments.
605 //===----------------------------------------------------------------------===//
606 
607 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
608   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
609   // Parse list of region arguments without a delimiter.
610   if (parser.parseRegionArgumentList(ivsInfo))
611     return failure();
612 
613   // Parse the body region.
614   Region *body = result.addRegion();
615   auto &builder = parser.getBuilder();
616   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
617   return parser.parseRegion(*body, ivsInfo, argTypes);
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // Test removing op with inner ops.
622 //===----------------------------------------------------------------------===//
623 
624 namespace {
625 struct TestRemoveOpWithInnerOps
626     : public OpRewritePattern<TestOpWithRegionPattern> {
627   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
628 
629   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
630                                 PatternRewriter &rewriter) const override {
631     rewriter.eraseOp(op);
632     return success();
633   }
634 };
635 } // end anonymous namespace
636 
637 void TestOpWithRegionPattern::getCanonicalizationPatterns(
638     RewritePatternSet &results, MLIRContext *context) {
639   results.add<TestRemoveOpWithInnerOps>(context);
640 }
641 
642 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
643   return operand();
644 }
645 
646 OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
647   return getValue();
648 }
649 
650 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
651     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
652   for (Value input : this->operands()) {
653     results.push_back(input);
654   }
655   return success();
656 }
657 
658 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
659   assert(operands.size() == 1);
660   if (operands.front()) {
661     (*this)->setAttr("attr", operands.front());
662     return getResult();
663   }
664   return {};
665 }
666 
667 OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
668   return getOperand();
669 }
670 
671 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
672     MLIRContext *, Optional<Location> location, ValueRange operands,
673     DictionaryAttr attributes, RegionRange regions,
674     SmallVectorImpl<Type> &inferredReturnTypes) {
675   if (operands[0].getType() != operands[1].getType()) {
676     return emitOptionalError(location, "operand type mismatch ",
677                              operands[0].getType(), " vs ",
678                              operands[1].getType());
679   }
680   inferredReturnTypes.assign({operands[0].getType()});
681   return success();
682 }
683 
684 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
685     MLIRContext *context, Optional<Location> location, ValueRange operands,
686     DictionaryAttr attributes, RegionRange regions,
687     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
688   // Create return type consisting of the last element of the first operand.
689   auto operandType = *operands.getTypes().begin();
690   auto sval = operandType.dyn_cast<ShapedType>();
691   if (!sval) {
692     return emitOptionalError(location, "only shaped type operands allowed");
693   }
694   int64_t dim =
695       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
696   auto type = IntegerType::get(context, 17);
697   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
698   return success();
699 }
700 
701 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
702     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
703   shapes = SmallVector<Value, 1>{
704       builder.createOrFold<memref::DimOp>(getLoc(), getOperand(0), 0)};
705   return success();
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // Test SideEffect interfaces
710 //===----------------------------------------------------------------------===//
711 
712 namespace {
713 /// A test resource for side effects.
714 struct TestResource : public SideEffects::Resource::Base<TestResource> {
715   StringRef getName() final { return "<Test>"; }
716 };
717 } // end anonymous namespace
718 
719 void SideEffectOp::getEffects(
720     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
721   // Check for an effects attribute on the op instance.
722   ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
723   if (!effectsAttr)
724     return;
725 
726   // If there is one, it is an array of dictionary attributes that hold
727   // information on the effects of this operation.
728   for (Attribute element : effectsAttr) {
729     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
730 
731     // Get the specific memory effect.
732     MemoryEffects::Effect *effect =
733         StringSwitch<MemoryEffects::Effect *>(
734             effectElement.get("effect").cast<StringAttr>().getValue())
735             .Case("allocate", MemoryEffects::Allocate::get())
736             .Case("free", MemoryEffects::Free::get())
737             .Case("read", MemoryEffects::Read::get())
738             .Case("write", MemoryEffects::Write::get());
739 
740     // Check for a non-default resource to use.
741     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
742     if (effectElement.get("test_resource"))
743       resource = TestResource::get();
744 
745     // Check for a result to affect.
746     if (effectElement.get("on_result"))
747       effects.emplace_back(effect, getResult(), resource);
748     else if (Attribute ref = effectElement.get("on_reference"))
749       effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
750     else
751       effects.emplace_back(effect, resource);
752   }
753 }
754 
755 void SideEffectOp::getEffects(
756     SmallVectorImpl<TestEffects::EffectInstance> &effects) {
757   auto effectsAttr = (*this)->getAttrOfType<AffineMapAttr>("effect_parameter");
758   if (!effectsAttr)
759     return;
760 
761   effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // StringAttrPrettyNameOp
766 //===----------------------------------------------------------------------===//
767 
768 // This op has fancy handling of its SSA result name.
769 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
770                                                OperationState &result) {
771   // Add the result types.
772   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
773     result.addTypes(parser.getBuilder().getIntegerType(32));
774 
775   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
776     return failure();
777 
778   // If the attribute dictionary contains no 'names' attribute, infer it from
779   // the SSA name (if specified).
780   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
781     return attr.first == "names";
782   });
783 
784   // If there was no name specified, check to see if there was a useful name
785   // specified in the asm file.
786   if (hadNames || parser.getNumResults() == 0)
787     return success();
788 
789   SmallVector<StringRef, 4> names;
790   auto *context = result.getContext();
791 
792   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
793     auto resultName = parser.getResultName(i);
794     StringRef nameStr;
795     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
796       nameStr = resultName.first;
797 
798     names.push_back(nameStr);
799   }
800 
801   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
802   result.attributes.push_back({Identifier::get("names", context), namesAttr});
803   return success();
804 }
805 
806 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
807   p << "test.string_attr_pretty_name";
808 
809   // Note that we only need to print the "name" attribute if the asmprinter
810   // result name disagrees with it.  This can happen in strange cases, e.g.
811   // when there are conflicts.
812   bool namesDisagree = op.names().size() != op.getNumResults();
813 
814   SmallString<32> resultNameStr;
815   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
816     resultNameStr.clear();
817     llvm::raw_svector_ostream tmpStream(resultNameStr);
818     p.printOperand(op.getResult(i), tmpStream);
819 
820     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
821     if (!expectedName ||
822         tmpStream.str().drop_front() != expectedName.getValue()) {
823       namesDisagree = true;
824     }
825   }
826 
827   if (namesDisagree)
828     p.printOptionalAttrDictWithKeyword(op->getAttrs());
829   else
830     p.printOptionalAttrDictWithKeyword(op->getAttrs(), {"names"});
831 }
832 
833 // We set the SSA name in the asm syntax to the contents of the name
834 // attribute.
835 void StringAttrPrettyNameOp::getAsmResultNames(
836     function_ref<void(Value, StringRef)> setNameFn) {
837 
838   auto value = names();
839   for (size_t i = 0, e = value.size(); i != e; ++i)
840     if (auto str = value[i].dyn_cast<StringAttr>())
841       if (!str.getValue().empty())
842         setNameFn(getResult(i), str.getValue());
843 }
844 
845 //===----------------------------------------------------------------------===//
846 // RegionIfOp
847 //===----------------------------------------------------------------------===//
848 
849 static void print(OpAsmPrinter &p, RegionIfOp op) {
850   p << RegionIfOp::getOperationName() << " ";
851   p.printOperands(op.getOperands());
852   p << ": " << op.getOperandTypes();
853   p.printArrowTypeList(op.getResultTypes());
854   p << " then";
855   p.printRegion(op.thenRegion(),
856                 /*printEntryBlockArgs=*/true,
857                 /*printBlockTerminators=*/true);
858   p << " else";
859   p.printRegion(op.elseRegion(),
860                 /*printEntryBlockArgs=*/true,
861                 /*printBlockTerminators=*/true);
862   p << " join";
863   p.printRegion(op.joinRegion(),
864                 /*printEntryBlockArgs=*/true,
865                 /*printBlockTerminators=*/true);
866 }
867 
868 static ParseResult parseRegionIfOp(OpAsmParser &parser,
869                                    OperationState &result) {
870   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
871   SmallVector<Type, 2> operandTypes;
872 
873   result.regions.reserve(3);
874   Region *thenRegion = result.addRegion();
875   Region *elseRegion = result.addRegion();
876   Region *joinRegion = result.addRegion();
877 
878   // Parse operand, type and arrow type lists.
879   if (parser.parseOperandList(operandInfos) ||
880       parser.parseColonTypeList(operandTypes) ||
881       parser.parseArrowTypeList(result.types))
882     return failure();
883 
884   // Parse all attached regions.
885   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
886       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
887       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
888     return failure();
889 
890   return parser.resolveOperands(operandInfos, operandTypes,
891                                 parser.getCurrentLocation(), result.operands);
892 }
893 
894 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
895   assert(index < 2 && "invalid region index");
896   return getOperands();
897 }
898 
899 void RegionIfOp::getSuccessorRegions(
900     Optional<unsigned> index, ArrayRef<Attribute> operands,
901     SmallVectorImpl<RegionSuccessor> &regions) {
902   // We always branch to the join region.
903   if (index.hasValue()) {
904     if (index.getValue() < 2)
905       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
906     else
907       regions.push_back(RegionSuccessor(getResults()));
908     return;
909   }
910 
911   // The then and else regions are the entry regions of this op.
912   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
913   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
914 }
915 
916 #include "TestOpEnums.cpp.inc"
917 #include "TestOpInterfaces.cpp.inc"
918 #include "TestOpStructs.cpp.inc"
919 #include "TestTypeInterfaces.cpp.inc"
920 
921 #define GET_OP_CLASSES
922 #include "TestOps.cpp.inc"
923