1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "TestTypes.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Function.h"
14 #include "mlir/IR/Module.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/IR/TypeUtilities.h"
17 #include "mlir/Transforms/FoldUtils.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/StringSwitch.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // TestDialect Interfaces
25 //===----------------------------------------------------------------------===//
26 
27 namespace {
28 
29 // Test support for interacting with the AsmPrinter.
30 struct TestOpAsmInterface : public OpAsmDialectInterface {
31   using OpAsmDialectInterface::OpAsmDialectInterface;
32 
33   void getAsmResultNames(Operation *op,
34                          OpAsmSetValueNameFn setNameFn) const final {
35     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
36       setNameFn(asmOp, "result");
37   }
38 
39   void getAsmBlockArgumentNames(Block *block,
40                                 OpAsmSetValueNameFn setNameFn) const final {
41     auto op = block->getParentOp();
42     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
43     if (!arrayAttr)
44       return;
45     auto args = block->getArguments();
46     auto e = std::min(arrayAttr.size(), args.size());
47     for (unsigned i = 0; i < e; ++i) {
48       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
49         setNameFn(args[i], strAttr.getValue());
50     }
51   }
52 };
53 
54 struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
55   using OpFolderDialectInterface::OpFolderDialectInterface;
56 
57   /// Registered hook to check if the given region, which is attached to an
58   /// operation that is *not* isolated from above, should be used when
59   /// materializing constants.
60   bool shouldMaterializeInto(Region *region) const final {
61     // If this is a one region operation, then insert into it.
62     return isa<OneRegionOp>(region->getParentOp());
63   }
64 };
65 
66 /// This class defines the interface for handling inlining with standard
67 /// operations.
68 struct TestInlinerInterface : public DialectInlinerInterface {
69   using DialectInlinerInterface::DialectInlinerInterface;
70 
71   //===--------------------------------------------------------------------===//
72   // Analysis Hooks
73   //===--------------------------------------------------------------------===//
74 
75   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
76     // Inlining into test dialect regions is legal.
77     return true;
78   }
79   bool isLegalToInline(Operation *, Region *,
80                        BlockAndValueMapping &) const final {
81     return true;
82   }
83 
84   bool shouldAnalyzeRecursively(Operation *op) const final {
85     // Analyze recursively if this is not a functional region operation, it
86     // froms a separate functional scope.
87     return !isa<FunctionalRegionOp>(op);
88   }
89 
90   //===--------------------------------------------------------------------===//
91   // Transformation Hooks
92   //===--------------------------------------------------------------------===//
93 
94   /// Handle the given inlined terminator by replacing it with a new operation
95   /// as necessary.
96   void handleTerminator(Operation *op,
97                         ArrayRef<Value> valuesToRepl) const final {
98     // Only handle "test.return" here.
99     auto returnOp = dyn_cast<TestReturnOp>(op);
100     if (!returnOp)
101       return;
102 
103     // Replace the values directly with the return operands.
104     assert(returnOp.getNumOperands() == valuesToRepl.size());
105     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
106       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
107   }
108 
109   /// Attempt to materialize a conversion for a type mismatch between a call
110   /// from this dialect, and a callable region. This method should generate an
111   /// operation that takes 'input' as the only operand, and produces a single
112   /// result of 'resultType'. If a conversion can not be generated, nullptr
113   /// should be returned.
114   Operation *materializeCallConversion(OpBuilder &builder, Value input,
115                                        Type resultType,
116                                        Location conversionLoc) const final {
117     // Only allow conversion for i16/i32 types.
118     if (!(resultType.isSignlessInteger(16) ||
119           resultType.isSignlessInteger(32)) ||
120         !(input.getType().isSignlessInteger(16) ||
121           input.getType().isSignlessInteger(32)))
122       return nullptr;
123     return builder.create<TestCastOp>(conversionLoc, resultType, input);
124   }
125 };
126 } // end anonymous namespace
127 
128 //===----------------------------------------------------------------------===//
129 // TestDialect
130 //===----------------------------------------------------------------------===//
131 
/// Constructs the test dialect. Registers every operation from the generated
/// op list, the asm/folder/inliner dialect interfaces defined in this file,
/// and TestType, and allows unknown operations.
TestDialect::TestDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // With GET_OP_LIST defined, the generated include supplies the op class
  // names used as the template arguments of addOperations.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
                TestInlinerInterface>();
  addTypes<TestType>();
  // Ops whose name is not registered above are still accepted.
  allowUnknownOperations();
}
143 
144 Type TestDialect::parseType(DialectAsmParser &parser) const {
145   if (failed(parser.parseKeyword("test_type")))
146     return Type();
147   return TestType::get(getContext());
148 }
149 
150 void TestDialect::printType(Type type, DialectAsmPrinter &printer) const {
151   assert(type.isa<TestType>() && "unexpected type");
152   printer << "test_type";
153 }
154 
155 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
156                                                     NamedAttribute namedAttr) {
157   if (namedAttr.first == "test.invalid_attr")
158     return op->emitError() << "invalid to use 'test.invalid_attr'";
159   return success();
160 }
161 
162 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
163                                                     unsigned regionIndex,
164                                                     unsigned argIndex,
165                                                     NamedAttribute namedAttr) {
166   if (namedAttr.first == "test.invalid_attr")
167     return op->emitError() << "invalid to use 'test.invalid_attr'";
168   return success();
169 }
170 
171 LogicalResult
172 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
173                                          unsigned resultIndex,
174                                          NamedAttribute namedAttr) {
175   if (namedAttr.first == "test.invalid_attr")
176     return op->emitError() << "invalid to use 'test.invalid_attr'";
177   return success();
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // TestBranchOp
182 //===----------------------------------------------------------------------===//
183 
184 Optional<MutableOperandRange>
185 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
186   assert(index == 0 && "invalid successor index");
187   return targetOperandsMutable();
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // TestFoldToCallOp
192 //===----------------------------------------------------------------------===//
193 
194 namespace {
195 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
196   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
197 
198   LogicalResult matchAndRewrite(FoldToCallOp op,
199                                 PatternRewriter &rewriter) const override {
200     rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
201                                         ValueRange());
202     return success();
203   }
204 };
205 } // end anonymous namespace
206 
207 void FoldToCallOp::getCanonicalizationPatterns(
208     OwningRewritePatternList &results, MLIRContext *context) {
209   results.insert<FoldToCallOpPattern>(context);
210 }
211 
212 //===----------------------------------------------------------------------===//
213 // Test IsolatedRegionOp - parse passthrough region arguments.
214 //===----------------------------------------------------------------------===//
215 
216 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
217                                          OperationState &result) {
218   OpAsmParser::OperandType argInfo;
219   Type argType = parser.getBuilder().getIndexType();
220 
221   // Parse the input operand.
222   if (parser.parseOperand(argInfo) ||
223       parser.resolveOperand(argInfo, argType, result.operands))
224     return failure();
225 
226   // Parse the body region, and reuse the operand info as the argument info.
227   Region *body = result.addRegion();
228   return parser.parseRegion(*body, argInfo, argType,
229                             /*enableNameShadowing=*/true);
230 }
231 
232 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
233   p << "test.isolated_region ";
234   p.printOperand(op.getOperand());
235   p.shadowRegionArgs(op.region(), op.getOperand());
236   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // Test AffineScopeOp
241 //===----------------------------------------------------------------------===//
242 
243 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
244                                       OperationState &result) {
245   // Parse the body region, and reuse the operand info as the argument info.
246   Region *body = result.addRegion();
247   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
248 }
249 
250 static void print(OpAsmPrinter &p, AffineScopeOp op) {
251   p << "test.affine_scope ";
252   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
253 }
254 
255 //===----------------------------------------------------------------------===//
// Test WrappedKeywordOp - parse a bare keyword into a string attribute.
257 //===----------------------------------------------------------------------===//
258 
259 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
260                                          OperationState &result) {
261   StringRef keyword;
262   if (parser.parseKeyword(&keyword))
263     return failure();
264   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
265   return success();
266 }
267 
268 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
269   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`.
274 
275 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
276                                          OperationState &result) {
277   if (parser.parseKeyword("wraps"))
278     return failure();
279 
280   // Parse the wrapped op in a region
281   Region &body = *result.addRegion();
282   body.push_back(new Block);
283   Block &block = body.back();
284   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
285   if (!wrapped_op)
286     return failure();
287 
288   // Create a return terminator in the inner region, pass as operand to the
289   // terminator the returned values from the wrapped operation.
290   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
291   OpBuilder builder(parser.getBuilder().getContext());
292   builder.setInsertionPointToEnd(&block);
293   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
294 
295   // Get the results type for the wrapping op from the terminator operands.
296   Operation &return_op = body.back().back();
297   result.types.append(return_op.operand_type_begin(),
298                       return_op.operand_type_end());
299 
300   // Use the location of the wrapped op for the "test.wrapping_region" op.
301   result.location = wrapped_op->getLoc();
302 
303   return success();
304 }
305 
306 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
307   p << op.getOperationName() << " wraps ";
308   p.printGenericOp(&op.region().front().front());
309 }
310 
311 //===----------------------------------------------------------------------===//
312 // Test PolyForOp - parse list of region arguments.
313 //===----------------------------------------------------------------------===//
314 
315 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
316   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
317   // Parse list of region arguments without a delimiter.
318   if (parser.parseRegionArgumentList(ivsInfo))
319     return failure();
320 
321   // Parse the body region.
322   Region *body = result.addRegion();
323   auto &builder = parser.getBuilder();
324   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
325   return parser.parseRegion(*body, ivsInfo, argTypes);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 // Test removing op with inner ops.
330 //===----------------------------------------------------------------------===//
331 
332 namespace {
333 struct TestRemoveOpWithInnerOps
334     : public OpRewritePattern<TestOpWithRegionPattern> {
335   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
336 
337   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
338                                 PatternRewriter &rewriter) const override {
339     rewriter.eraseOp(op);
340     return success();
341   }
342 };
343 } // end anonymous namespace
344 
345 void TestOpWithRegionPattern::getCanonicalizationPatterns(
346     OwningRewritePatternList &results, MLIRContext *context) {
347   results.insert<TestRemoveOpWithInnerOps>(context);
348 }
349 
350 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
351   return operand();
352 }
353 
354 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
355     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
356   for (Value input : this->operands()) {
357     results.push_back(input);
358   }
359   return success();
360 }
361 
362 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
363   assert(operands.size() == 1);
364   if (operands.front()) {
365     setAttr("attr", operands.front());
366     return getResult();
367   }
368   return {};
369 }
370 
371 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
372     MLIRContext *, Optional<Location> location, ValueRange operands,
373     DictionaryAttr attributes, RegionRange regions,
374     SmallVectorImpl<Type> &inferredReturnTypes) {
375   if (operands[0].getType() != operands[1].getType()) {
376     return emitOptionalError(location, "operand type mismatch ",
377                              operands[0].getType(), " vs ",
378                              operands[1].getType());
379   }
380   inferredReturnTypes.assign({operands[0].getType()});
381   return success();
382 }
383 
384 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
385     MLIRContext *context, Optional<Location> location, ValueRange operands,
386     DictionaryAttr attributes, RegionRange regions,
387     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
388   // Create return type consisting of the last element of the first operand.
389   auto operandType = *operands.getTypes().begin();
390   auto sval = operandType.dyn_cast<ShapedType>();
391   if (!sval) {
392     return emitOptionalError(location, "only shaped type operands allowed");
393   }
394   int64_t dim =
395       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
396   auto type = IntegerType::get(17, context);
397   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
398   return success();
399 }
400 
401 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
402     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
403   shapes = SmallVector<Value, 1>{
404       builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)};
405   return success();
406 }
407 
408 //===----------------------------------------------------------------------===//
409 // Test SideEffect interfaces
410 //===----------------------------------------------------------------------===//
411 
412 namespace {
413 /// A test resource for side effects.
414 struct TestResource : public SideEffects::Resource::Base<TestResource> {
415   StringRef getName() final { return "<Test>"; }
416 };
417 } // end anonymous namespace
418 
419 void SideEffectOp::getEffects(
420     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
421   // Check for an effects attribute on the op instance.
422   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
423   if (!effectsAttr)
424     return;
425 
426   // If there is one, it is an array of dictionary attributes that hold
427   // information on the effects of this operation.
428   for (Attribute element : effectsAttr) {
429     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
430 
431     // Get the specific memory effect.
432     MemoryEffects::Effect *effect =
433         llvm::StringSwitch<MemoryEffects::Effect *>(
434             effectElement.get("effect").cast<StringAttr>().getValue())
435             .Case("allocate", MemoryEffects::Allocate::get())
436             .Case("free", MemoryEffects::Free::get())
437             .Case("read", MemoryEffects::Read::get())
438             .Case("write", MemoryEffects::Write::get());
439 
440     // Check for a result to affect.
441     Value value;
442     if (effectElement.get("on_result"))
443       value = getResult();
444 
445     // Check for a non-default resource to use.
446     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
447     if (effectElement.get("test_resource"))
448       resource = TestResource::get();
449 
450     effects.emplace_back(effect, value, resource);
451   }
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // StringAttrPrettyNameOp
456 //===----------------------------------------------------------------------===//
457 
458 // This op has fancy handling of its SSA result name.
459 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
460                                                OperationState &result) {
461   // Add the result types.
462   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
463     result.addTypes(parser.getBuilder().getIntegerType(32));
464 
465   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
466     return failure();
467 
468   // If the attribute dictionary contains no 'names' attribute, infer it from
469   // the SSA name (if specified).
470   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
471     return attr.first == "names";
472   });
473 
474   // If there was no name specified, check to see if there was a useful name
475   // specified in the asm file.
476   if (hadNames || parser.getNumResults() == 0)
477     return success();
478 
479   SmallVector<StringRef, 4> names;
480   auto *context = result.getContext();
481 
482   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
483     auto resultName = parser.getResultName(i);
484     StringRef nameStr;
485     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
486       nameStr = resultName.first;
487 
488     names.push_back(nameStr);
489   }
490 
491   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
492   result.attributes.push_back({Identifier::get("names", context), namesAttr});
493   return success();
494 }
495 
496 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
497   p << "test.string_attr_pretty_name";
498 
499   // Note that we only need to print the "name" attribute if the asmprinter
500   // result name disagrees with it.  This can happen in strange cases, e.g.
501   // when there are conflicts.
502   bool namesDisagree = op.names().size() != op.getNumResults();
503 
504   SmallString<32> resultNameStr;
505   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
506     resultNameStr.clear();
507     llvm::raw_svector_ostream tmpStream(resultNameStr);
508     p.printOperand(op.getResult(i), tmpStream);
509 
510     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
511     if (!expectedName ||
512         tmpStream.str().drop_front() != expectedName.getValue()) {
513       namesDisagree = true;
514     }
515   }
516 
517   if (namesDisagree)
518     p.printOptionalAttrDictWithKeyword(op.getAttrs());
519   else
520     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
521 }
522 
523 // We set the SSA name in the asm syntax to the contents of the name
524 // attribute.
525 void StringAttrPrettyNameOp::getAsmResultNames(
526     function_ref<void(Value, StringRef)> setNameFn) {
527 
528   auto value = names();
529   for (size_t i = 0, e = value.size(); i != e; ++i)
530     if (auto str = value[i].dyn_cast<StringAttr>())
531       if (!str.getValue().empty())
532         setNameFn(getResult(i), str.getValue());
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // RegionIfOp
537 //===----------------------------------------------------------------------===//
538 
539 static void print(OpAsmPrinter &p, RegionIfOp op) {
540   p << RegionIfOp::getOperationName() << " ";
541   p.printOperands(op.getOperands());
542   p << ": " << op.getOperandTypes();
543   p.printArrowTypeList(op.getResultTypes());
544   p << " then";
545   p.printRegion(op.thenRegion(),
546                 /*printEntryBlockArgs=*/true,
547                 /*printBlockTerminators=*/true);
548   p << " else";
549   p.printRegion(op.elseRegion(),
550                 /*printEntryBlockArgs=*/true,
551                 /*printBlockTerminators=*/true);
552   p << " join";
553   p.printRegion(op.joinRegion(),
554                 /*printEntryBlockArgs=*/true,
555                 /*printBlockTerminators=*/true);
556 }
557 
558 static ParseResult parseRegionIfOp(OpAsmParser &parser,
559                                    OperationState &result) {
560   SmallVector<OpAsmParser::OperandType, 2> operandInfos;
561   SmallVector<Type, 2> operandTypes;
562 
563   result.regions.reserve(3);
564   Region *thenRegion = result.addRegion();
565   Region *elseRegion = result.addRegion();
566   Region *joinRegion = result.addRegion();
567 
568   // Parse operand, type and arrow type lists.
569   if (parser.parseOperandList(operandInfos) ||
570       parser.parseColonTypeList(operandTypes) ||
571       parser.parseArrowTypeList(result.types))
572     return failure();
573 
574   // Parse all attached regions.
575   if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
576       parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
577       parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
578     return failure();
579 
580   return parser.resolveOperands(operandInfos, operandTypes,
581                                 parser.getCurrentLocation(), result.operands);
582 }
583 
584 OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) {
585   assert(index < 2 && "invalid region index");
586   return getOperands();
587 }
588 
589 void RegionIfOp::getSuccessorRegions(
590     Optional<unsigned> index, ArrayRef<Attribute> operands,
591     SmallVectorImpl<RegionSuccessor> &regions) {
592   // We always branch to the join region.
593   if (index.hasValue()) {
594     if (index.getValue() < 2)
595       regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs()));
596     else
597       regions.push_back(RegionSuccessor(getResults()));
598     return;
599   }
600 
601   // The then and else regions are the entry regions of this op.
602   regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs()));
603   regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs()));
604 }
605 
606 //===----------------------------------------------------------------------===//
607 // Dialect Registration
608 //===----------------------------------------------------------------------===//
609 
// Static initialization for Test dialect registration: this global object
// registers mlir::TestDialect while the translation unit is statically
// initialized.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
612 
613 #include "TestOpEnums.cpp.inc"
614 #include "TestOpStructs.cpp.inc"
615 #include "TestTypeInterfaces.cpp.inc"
616 
617 #define GET_OP_CLASSES
618 #include "TestOps.cpp.inc"
619