1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "TestDialect.h"
#include "TestTypes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringSwitch.h"
21 
22 using namespace mlir;
23 
24 void mlir::registerTestDialect(DialectRegistry &registry) {
25   registry.insert<TestDialect>();
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // TestDialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 
34 // Test support for interacting with the AsmPrinter.
35 struct TestOpAsmInterface : public OpAsmDialectInterface {
36   using OpAsmDialectInterface::OpAsmDialectInterface;
37 
38   void getAsmResultNames(Operation *op,
39                          OpAsmSetValueNameFn setNameFn) const final {
40     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
41       setNameFn(asmOp, "result");
42   }
43 
44   void getAsmBlockArgumentNames(Block *block,
45                                 OpAsmSetValueNameFn setNameFn) const final {
46     auto op = block->getParentOp();
47     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
48     if (!arrayAttr)
49       return;
50     auto args = block->getArguments();
51     auto e = std::min(arrayAttr.size(), args.size());
52     for (unsigned i = 0; i < e; ++i) {
53       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
54         setNameFn(args[i], strAttr.getValue());
55     }
56   }
57 };
58 
59 struct TestDialectFoldInterface : public DialectFoldInterface {
60   using DialectFoldInterface::DialectFoldInterface;
61 
62   /// Registered hook to check if the given region, which is attached to an
63   /// operation that is *not* isolated from above, should be used when
64   /// materializing constants.
65   bool shouldMaterializeInto(Region *region) const final {
66     // If this is a one region operation, then insert into it.
67     return isa<OneRegionOp>(region->getParentOp());
68   }
69 };
70 
71 /// This class defines the interface for handling inlining with standard
72 /// operations.
73 struct TestInlinerInterface : public DialectInlinerInterface {
74   using DialectInlinerInterface::DialectInlinerInterface;
75 
76   //===--------------------------------------------------------------------===//
77   // Analysis Hooks
78   //===--------------------------------------------------------------------===//
79 
80   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
81     // Inlining into test dialect regions is legal.
82     return true;
83   }
84   bool isLegalToInline(Operation *, Region *,
85                        BlockAndValueMapping &) const final {
86     return true;
87   }
88 
89   bool shouldAnalyzeRecursively(Operation *op) const final {
90     // Analyze recursively if this is not a functional region operation, it
91     // froms a separate functional scope.
92     return !isa<FunctionalRegionOp>(op);
93   }
94 
95   //===--------------------------------------------------------------------===//
96   // Transformation Hooks
97   //===--------------------------------------------------------------------===//
98 
99   /// Handle the given inlined terminator by replacing it with a new operation
100   /// as necessary.
101   void handleTerminator(Operation *op,
102                         ArrayRef<Value> valuesToRepl) const final {
103     // Only handle "test.return" here.
104     auto returnOp = dyn_cast<TestReturnOp>(op);
105     if (!returnOp)
106       return;
107 
108     // Replace the values directly with the return operands.
109     assert(returnOp.getNumOperands() == valuesToRepl.size());
110     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
111       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
112   }
113 
114   /// Attempt to materialize a conversion for a type mismatch between a call
115   /// from this dialect, and a callable region. This method should generate an
116   /// operation that takes 'input' as the only operand, and produces a single
117   /// result of 'resultType'. If a conversion can not be generated, nullptr
118   /// should be returned.
119   Operation *materializeCallConversion(OpBuilder &builder, Value input,
120                                        Type resultType,
121                                        Location conversionLoc) const final {
122     // Only allow conversion for i16/i32 types.
123     if (!(resultType.isSignlessInteger(16) ||
124           resultType.isSignlessInteger(32)) ||
125         !(input.getType().isSignlessInteger(16) ||
126           input.getType().isSignlessInteger(32)))
127       return nullptr;
128     return builder.create<TestCastOp>(conversionLoc, resultType, input);
129   }
130 };
131 } // end anonymous namespace
132 
133 //===----------------------------------------------------------------------===//
134 // TestDialect
135 //===----------------------------------------------------------------------===//
136 
/// Registers the dialect's operations, dialect interfaces and types, and calls
/// allowUnknownOperations() so unregistered operations are accepted for this
/// dialect.
void TestDialect::initialize() {
  // The operation list is expanded from TestOps.cpp.inc via GET_OP_LIST.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  // Interfaces defined in the anonymous namespace of this file.
  addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
                TestInlinerInterface>();
  addTypes<TestType, TestRecursiveType>();
  allowUnknownOperations();
}
147 
148 static Type parseTestType(DialectAsmParser &parser,
149                           llvm::SetVector<Type> &stack) {
150   StringRef typeTag;
151   if (failed(parser.parseKeyword(&typeTag)))
152     return Type();
153 
154   if (typeTag == "test_type")
155     return TestType::get(parser.getBuilder().getContext());
156 
157   if (typeTag != "test_rec")
158     return Type();
159 
160   StringRef name;
161   if (parser.parseLess() || parser.parseKeyword(&name))
162     return Type();
163   auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name);
164 
165   // If this type already has been parsed above in the stack, expect just the
166   // name.
167   if (stack.contains(rec)) {
168     if (failed(parser.parseGreater()))
169       return Type();
170     return rec;
171   }
172 
173   // Otherwise, parse the body and update the type.
174   if (failed(parser.parseComma()))
175     return Type();
176   stack.insert(rec);
177   Type subtype = parseTestType(parser, stack);
178   stack.pop_back();
179   if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
180     return Type();
181 
182   return rec;
183 }
184 
185 Type TestDialect::parseType(DialectAsmParser &parser) const {
186   llvm::SetVector<Type> stack;
187   return parseTestType(parser, stack);
188 }
189 
190 static void printTestType(Type type, DialectAsmPrinter &printer,
191                           llvm::SetVector<Type> &stack) {
192   if (type.isa<TestType>()) {
193     printer << "test_type";
194     return;
195   }
196 
197   auto rec = type.cast<TestRecursiveType>();
198   printer << "test_rec<" << rec.getName();
199   if (!stack.contains(rec)) {
200     printer << ", ";
201     stack.insert(rec);
202     printTestType(rec.getBody(), printer, stack);
203     stack.pop_back();
204   }
205   printer << ">";
206 }
207 
208 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
209   llvm::SetVector<Type> stack;
210   printTestType(type, printer, stack);
211 }
212 
213 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
214                                                     NamedAttribute namedAttr) {
215   if (namedAttr.first == "test.invalid_attr")
216     return op->emitError() << "invalid to use 'test.invalid_attr'";
217   return success();
218 }
219 
220 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
221                                                     unsigned regionIndex,
222                                                     unsigned argIndex,
223                                                     NamedAttribute namedAttr) {
224   if (namedAttr.first == "test.invalid_attr")
225     return op->emitError() << "invalid to use 'test.invalid_attr'";
226   return success();
227 }
228 
229 LogicalResult
230 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
231                                          unsigned resultIndex,
232                                          NamedAttribute namedAttr) {
233   if (namedAttr.first == "test.invalid_attr")
234     return op->emitError() << "invalid to use 'test.invalid_attr'";
235   return success();
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // TestBranchOp
240 //===----------------------------------------------------------------------===//
241 
242 Optional<MutableOperandRange>
243 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
244   assert(index == 0 && "invalid successor index");
245   return targetOperandsMutable();
246 }
247 
248 //===----------------------------------------------------------------------===//
249 // TestFoldToCallOp
250 //===----------------------------------------------------------------------===//
251 
252 namespace {
253 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
254   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
255 
256   LogicalResult matchAndRewrite(FoldToCallOp op,
257                                 PatternRewriter &rewriter) const override {
258     rewriter.replaceOpWithNewOp<CallOp>(op, TypeRange(), op.calleeAttr(),
259                                         ValueRange());
260     return success();
261   }
262 };
263 } // end anonymous namespace
264 
265 void FoldToCallOp::getCanonicalizationPatterns(
266     OwningRewritePatternList &results, MLIRContext *context) {
267   results.insert<FoldToCallOpPattern>(context);
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // Test Format* operations
272 //===----------------------------------------------------------------------===//
273 
274 //===----------------------------------------------------------------------===//
275 // Parsing
276 
277 static ParseResult parseCustomDirectiveOperands(
278     OpAsmParser &parser, OpAsmParser::OperandType &operand,
279     Optional<OpAsmParser::OperandType> &optOperand,
280     SmallVectorImpl<OpAsmParser::OperandType> &varOperands) {
281   if (parser.parseOperand(operand))
282     return failure();
283   if (succeeded(parser.parseOptionalComma())) {
284     optOperand.emplace();
285     if (parser.parseOperand(*optOperand))
286       return failure();
287   }
288   if (parser.parseArrow() || parser.parseLParen() ||
289       parser.parseOperandList(varOperands) || parser.parseRParen())
290     return failure();
291   return success();
292 }
293 static ParseResult
294 parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
295                             Type &optOperandType,
296                             SmallVectorImpl<Type> &varOperandTypes) {
297   if (parser.parseColon())
298     return failure();
299 
300   if (parser.parseType(operandType))
301     return failure();
302   if (succeeded(parser.parseOptionalComma())) {
303     if (parser.parseType(optOperandType))
304       return failure();
305   }
306   if (parser.parseArrow() || parser.parseLParen() ||
307       parser.parseTypeList(varOperandTypes) || parser.parseRParen())
308     return failure();
309   return success();
310 }
311 static ParseResult
312 parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
313                                  Type optOperandType,
314                                  const SmallVectorImpl<Type> &varOperandTypes) {
315   if (parser.parseKeyword("type_refs_capture"))
316     return failure();
317 
318   Type operandType2, optOperandType2;
319   SmallVector<Type, 1> varOperandTypes2;
320   if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
321                                   varOperandTypes2))
322     return failure();
323 
324   if (operandType != operandType2 || optOperandType != optOperandType2 ||
325       varOperandTypes != varOperandTypes2)
326     return failure();
327 
328   return success();
329 }
330 static ParseResult parseCustomDirectiveOperandsAndTypes(
331     OpAsmParser &parser, OpAsmParser::OperandType &operand,
332     Optional<OpAsmParser::OperandType> &optOperand,
333     SmallVectorImpl<OpAsmParser::OperandType> &varOperands, Type &operandType,
334     Type &optOperandType, SmallVectorImpl<Type> &varOperandTypes) {
335   if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
336       parseCustomDirectiveResults(parser, operandType, optOperandType,
337                                   varOperandTypes))
338     return failure();
339   return success();
340 }
341 static ParseResult parseCustomDirectiveRegions(
342     OpAsmParser &parser, Region &region,
343     SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
344   if (parser.parseRegion(region))
345     return failure();
346   if (failed(parser.parseOptionalComma()))
347     return success();
348   std::unique_ptr<Region> varRegion = std::make_unique<Region>();
349   if (parser.parseRegion(*varRegion))
350     return failure();
351   varRegions.emplace_back(std::move(varRegion));
352   return success();
353 }
354 static ParseResult
355 parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
356                                SmallVectorImpl<Block *> &varSuccessors) {
357   if (parser.parseSuccessor(successor))
358     return failure();
359   if (failed(parser.parseOptionalComma()))
360     return success();
361   Block *varSuccessor;
362   if (parser.parseSuccessor(varSuccessor))
363     return failure();
364   varSuccessors.append(2, varSuccessor);
365   return success();
366 }
367 static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
368                                                   IntegerAttr &attr,
369                                                   IntegerAttr &optAttr) {
370   if (parser.parseAttribute(attr))
371     return failure();
372   if (succeeded(parser.parseOptionalComma())) {
373     if (parser.parseAttribute(optAttr))
374       return failure();
375   }
376   return success();
377 }
378 
379 //===----------------------------------------------------------------------===//
380 // Printing
381 
382 static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand,
383                                          Value optOperand,
384                                          OperandRange varOperands) {
385   printer << operand;
386   if (optOperand)
387     printer << ", " << optOperand;
388   printer << " -> (" << varOperands << ")";
389 }
390 static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType,
391                                         Type optOperandType,
392                                         TypeRange varOperandTypes) {
393   printer << " : " << operandType;
394   if (optOperandType)
395     printer << ", " << optOperandType;
396   printer << " -> (" << varOperandTypes << ")";
397 }
398 static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
399                                              Type operandType,
400                                              Type optOperandType,
401                                              TypeRange varOperandTypes) {
402   printer << " type_refs_capture ";
403   printCustomDirectiveResults(printer, operandType, optOperandType,
404                               varOperandTypes);
405 }
406 static void
407 printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand,
408                                      Value optOperand, OperandRange varOperands,
409                                      Type operandType, Type optOperandType,
410                                      TypeRange varOperandTypes) {
411   printCustomDirectiveOperands(printer, operand, optOperand, varOperands);
412   printCustomDirectiveResults(printer, operandType, optOperandType,
413                               varOperandTypes);
414 }
415 static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region,
416                                         MutableArrayRef<Region> varRegions) {
417   printer.printRegion(region);
418   if (!varRegions.empty()) {
419     printer << ", ";
420     for (Region &region : varRegions)
421       printer.printRegion(region);
422   }
423 }
424 static void printCustomDirectiveSuccessors(OpAsmPrinter &printer,
425                                            Block *successor,
426                                            SuccessorRange varSuccessors) {
427   printer << successor;
428   if (!varSuccessors.empty())
429     printer << ", " << varSuccessors.front();
430 }
431 static void printCustomDirectiveAttributes(OpAsmPrinter &printer,
432                                            Attribute attribute,
433                                            Attribute optAttribute) {
434   printer << attribute;
435   if (optAttribute)
436     printer << ", " << optAttribute;
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // Test IsolatedRegionOp - parse passthrough region arguments.
441 //===----------------------------------------------------------------------===//
442 
443 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
444                                          OperationState &result) {
445   OpAsmParser::OperandType argInfo;
446   Type argType = parser.getBuilder().getIndexType();
447 
448   // Parse the input operand.
449   if (parser.parseOperand(argInfo) ||
450       parser.resolveOperand(argInfo, argType, result.operands))
451     return failure();
452 
453   // Parse the body region, and reuse the operand info as the argument info.
454   Region *body = result.addRegion();
455   return parser.parseRegion(*body, argInfo, argType,
456                             /*enableNameShadowing=*/true);
457 }
458 
459 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
460   p << "test.isolated_region ";
461   p.printOperand(op.getOperand());
462   p.shadowRegionArgs(op.region(), op.getOperand());
463   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
464 }
465 
466 //===----------------------------------------------------------------------===//
467 // Test SSACFGRegionOp
468 //===----------------------------------------------------------------------===//
469 
470 RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
471   return RegionKind::SSACFG;
472 }
473 
474 //===----------------------------------------------------------------------===//
475 // Test GraphRegionOp
476 //===----------------------------------------------------------------------===//
477 
478 static ParseResult parseGraphRegionOp(OpAsmParser &parser,
479                                       OperationState &result) {
480   // Parse the body region, and reuse the operand info as the argument info.
481   Region *body = result.addRegion();
482   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
483 }
484 
485 static void print(OpAsmPrinter &p, GraphRegionOp op) {
486   p << "test.graph_region ";
487   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
488 }
489 
490 RegionKind GraphRegionOp::getRegionKind(unsigned index) {
491   return RegionKind::Graph;
492 }
493 
494 //===----------------------------------------------------------------------===//
495 // Test AffineScopeOp
496 //===----------------------------------------------------------------------===//
497 
498 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
499                                       OperationState &result) {
500   // Parse the body region, and reuse the operand info as the argument info.
501   Region *body = result.addRegion();
502   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
503 }
504 
505 static void print(OpAsmPrinter &p, AffineScopeOp op) {
506   p << "test.affine_scope ";
507   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // Test parser.
512 //===----------------------------------------------------------------------===//
513 
514 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
515                                          OperationState &result) {
516   StringRef keyword;
517   if (parser.parseKeyword(&keyword))
518     return failure();
519   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
520   return success();
521 }
522 
523 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
524   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
525 }
526 
527 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
529 
530 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
531                                          OperationState &result) {
532   if (parser.parseKeyword("wraps"))
533     return failure();
534 
535   // Parse the wrapped op in a region
536   Region &body = *result.addRegion();
537   body.push_back(new Block);
538   Block &block = body.back();
539   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
540   if (!wrapped_op)
541     return failure();
542 
543   // Create a return terminator in the inner region, pass as operand to the
544   // terminator the returned values from the wrapped operation.
545   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
546   OpBuilder builder(parser.getBuilder().getContext());
547   builder.setInsertionPointToEnd(&block);
548   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
549 
550   // Get the results type for the wrapping op from the terminator operands.
551   Operation &return_op = body.back().back();
552   result.types.append(return_op.operand_type_begin(),
553                       return_op.operand_type_end());
554 
555   // Use the location of the wrapped op for the "test.wrapping_region" op.
556   result.location = wrapped_op->getLoc();
557 
558   return success();
559 }
560 
561 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
562   p << op.getOperationName() << " wraps ";
563   p.printGenericOp(&op.region().front().front());
564 }
565 
566 //===----------------------------------------------------------------------===//
567 // Test PolyForOp - parse list of region arguments.
568 //===----------------------------------------------------------------------===//
569 
570 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
571   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
572   // Parse list of region arguments without a delimiter.
573   if (parser.parseRegionArgumentList(ivsInfo))
574     return failure();
575 
576   // Parse the body region.
577   Region *body = result.addRegion();
578   auto &builder = parser.getBuilder();
579   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
580   return parser.parseRegion(*body, ivsInfo, argTypes);
581 }
582 
583 //===----------------------------------------------------------------------===//
584 // Test removing op with inner ops.
585 //===----------------------------------------------------------------------===//
586 
587 namespace {
588 struct TestRemoveOpWithInnerOps
589     : public OpRewritePattern<TestOpWithRegionPattern> {
590   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
591 
592   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
593                                 PatternRewriter &rewriter) const override {
594     rewriter.eraseOp(op);
595     return success();
596   }
597 };
598 } // end anonymous namespace
599 
600 void TestOpWithRegionPattern::getCanonicalizationPatterns(
601     OwningRewritePatternList &results, MLIRContext *context) {
602   results.insert<TestRemoveOpWithInnerOps>(context);
603 }
604 
605 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
606   return operand();
607 }
608 
609 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
610     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
611   for (Value input : this->operands()) {
612     results.push_back(input);
613   }
614   return success();
615 }
616 
617 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
618   assert(operands.size() == 1);
619   if (operands.front()) {
620     setAttr("attr", operands.front());
621     return getResult();
622   }
623   return {};
624 }
625 
626 LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
627     MLIRContext *, Optional<Location> location, ValueRange operands,
628     DictionaryAttr attributes, RegionRange regions,
629     SmallVectorImpl<Type> &inferredReturnTypes) {
630   if (operands[0].getType() != operands[1].getType()) {
631     return emitOptionalError(location, "operand type mismatch ",
632                              operands[0].getType(), " vs ",
633                              operands[1].getType());
634   }
635   inferredReturnTypes.assign({operands[0].getType()});
636   return success();
637 }
638 
639 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
640     MLIRContext *context, Optional<Location> location, ValueRange operands,
641     DictionaryAttr attributes, RegionRange regions,
642     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
643   // Create return type consisting of the last element of the first operand.
644   auto operandType = *operands.getTypes().begin();
645   auto sval = operandType.dyn_cast<ShapedType>();
646   if (!sval) {
647     return emitOptionalError(location, "only shaped type operands allowed");
648   }
649   int64_t dim =
650       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
651   auto type = IntegerType::get(17, context);
652   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
653   return success();
654 }
655 
656 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
657     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
658   shapes = SmallVector<Value, 1>{
659       builder.createOrFold<DimOp>(getLoc(), getOperand(0), 0)};
660   return success();
661 }
662 
663 //===----------------------------------------------------------------------===//
664 // Test SideEffect interfaces
665 //===----------------------------------------------------------------------===//
666 
667 namespace {
668 /// A test resource for side effects.
669 struct TestResource : public SideEffects::Resource::Base<TestResource> {
670   StringRef getName() final { return "<Test>"; }
671 };
672 } // end anonymous namespace
673 
674 void SideEffectOp::getEffects(
675     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
676   // Check for an effects attribute on the op instance.
677   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
678   if (!effectsAttr)
679     return;
680 
681   // If there is one, it is an array of dictionary attributes that hold
682   // information on the effects of this operation.
683   for (Attribute element : effectsAttr) {
684     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
685 
686     // Get the specific memory effect.
687     MemoryEffects::Effect *effect =
688         llvm::StringSwitch<MemoryEffects::Effect *>(
689             effectElement.get("effect").cast<StringAttr>().getValue())
690             .Case("allocate", MemoryEffects::Allocate::get())
691             .Case("free", MemoryEffects::Free::get())
692             .Case("read", MemoryEffects::Read::get())
693             .Case("write", MemoryEffects::Write::get());
694 
695     // Check for a result to affect.
696     Value value;
697     if (effectElement.get("on_result"))
698       value = getResult();
699 
700     // Check for a non-default resource to use.
701     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
702     if (effectElement.get("test_resource"))
703       resource = TestResource::get();
704 
705     effects.emplace_back(effect, value, resource);
706   }
707 }
708 
709 //===----------------------------------------------------------------------===//
710 // StringAttrPrettyNameOp
711 //===----------------------------------------------------------------------===//
712 
713 // This op has fancy handling of its SSA result name.
714 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
715                                                OperationState &result) {
716   // Add the result types.
717   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
718     result.addTypes(parser.getBuilder().getIntegerType(32));
719 
720   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
721     return failure();
722 
723   // If the attribute dictionary contains no 'names' attribute, infer it from
724   // the SSA name (if specified).
725   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
726     return attr.first == "names";
727   });
728 
729   // If there was no name specified, check to see if there was a useful name
730   // specified in the asm file.
731   if (hadNames || parser.getNumResults() == 0)
732     return success();
733 
734   SmallVector<StringRef, 4> names;
735   auto *context = result.getContext();
736 
737   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
738     auto resultName = parser.getResultName(i);
739     StringRef nameStr;
740     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
741       nameStr = resultName.first;
742 
743     names.push_back(nameStr);
744   }
745 
746   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
747   result.attributes.push_back({Identifier::get("names", context), namesAttr});
748   return success();
749 }
750 
751 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
752   p << "test.string_attr_pretty_name";
753 
754   // Note that we only need to print the "name" attribute if the asmprinter
755   // result name disagrees with it.  This can happen in strange cases, e.g.
756   // when there are conflicts.
757   bool namesDisagree = op.names().size() != op.getNumResults();
758 
759   SmallString<32> resultNameStr;
760   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
761     resultNameStr.clear();
762     llvm::raw_svector_ostream tmpStream(resultNameStr);
763     p.printOperand(op.getResult(i), tmpStream);
764 
765     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
766     if (!expectedName ||
767         tmpStream.str().drop_front() != expectedName.getValue()) {
768       namesDisagree = true;
769     }
770   }
771 
772   if (namesDisagree)
773     p.printOptionalAttrDictWithKeyword(op.getAttrs());
774   else
775     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
776 }
777 
778 // We set the SSA name in the asm syntax to the contents of the name
779 // attribute.
780 void StringAttrPrettyNameOp::getAsmResultNames(
781     function_ref<void(Value, StringRef)> setNameFn) {
782 
783   auto value = names();
784   for (size_t i = 0, e = value.size(); i != e; ++i)
785     if (auto str = value[i].dyn_cast<StringAttr>())
786       if (!str.getValue().empty())
787         setNameFn(getResult(i), str.getValue());
788 }
789 
790 //===----------------------------------------------------------------------===//
791 // RegionIfOp
792 //===----------------------------------------------------------------------===//
793 
794 static void print(OpAsmPrinter &p, RegionIfOp op) {
795   p << RegionIfOp::getOperationName() << " ";
796   p.printOperands(op.getOperands());
797   p << ": " << op.getOperandTypes();
798   p.printArrowTypeList(op.getResultTypes());
799   p << " then";
800   p.printRegion(op.thenRegion(),
801                 /*printEntryBlockArgs=*/true,
802                 /*printBlockTerminators=*/true);
803   p << " else";
804   p.printRegion(op.elseRegion(),
805                 /*printEntryBlockArgs=*/true,
806                 /*printBlockTerminators=*/true);
807   p << " join";
808   p.printRegion(op.joinRegion(),
809                 /*printEntryBlockArgs=*/true,
810                 /*printBlockTerminators=*/true);
811 }
812 
813 static ParseResult parseRegionIfOp(OpAsmParser &parser,
814                                    OperationState &result) {
815   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
816   SmallVector<Type, 2> operandTypes;
817 
818   result.regions.reserve(3);
819   Region *thenRegion = result.addRegion();
820   Region *elseRegion = result.addRegion();
821   Region *joinRegion = result.addRegion();
822 
823   // Parse operand, type and arrow type lists.
824   if (parser.parseOperandList(operandInfos) ||
825       parser.parseColonTypeList(operandTypes) ||
826       parser.parseArrowTypeList(result.types))
827     return failure();
828 
829   // Parse all attached regions.
830   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
831       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
832       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
833     return failure();
834 
835   return parser.resolveOperands(operandInfos, operandTypes,
836                                 parser.getCurrentLocation(), result.operands);
837 }
838 
839 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
840   assert(index < 2 && "invalid region index");
841   return getOperands();
842 }
843 
844 void RegionIfOp::getSuccessorRegions(
845     Optional<unsigned> index, ArrayRef<Attribute> operands,
846     SmallVectorImpl<RegionSuccessor> &regions) {
847   // We always branch to the join region.
848   if (index.hasValue()) {
849     if (index.getValue() < 2)
850       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
851     else
852       regions.push_back(RegionSuccessor(getResults()));
853     return;
854   }
855 
856   // The then and else regions are the entry regions of this op.
857   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
858   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
859 }
860 
861 #include "TestOpEnums.cpp.inc"
862 #include "TestOpStructs.cpp.inc"
863 #include "TestTypeInterfaces.cpp.inc"
864 
865 #define GET_OP_CLASSES
866 #include "TestOps.cpp.inc"
867