1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Function.h"
12 #include "mlir/IR/Module.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/TypeUtilities.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/Transforms/InliningUtils.h"
17 #include "llvm/ADT/StringSwitch.h"
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // TestDialect Interfaces
23 //===----------------------------------------------------------------------===//
24 
25 namespace {
26 
27 // Test support for interacting with the AsmPrinter.
28 struct TestOpAsmInterface : public OpAsmDialectInterface {
29   using OpAsmDialectInterface::OpAsmDialectInterface;
30 
31   void getAsmResultNames(Operation *op,
32                          OpAsmSetValueNameFn setNameFn) const final {
33     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
34       setNameFn(asmOp, "result");
35   }
36 
37   void getAsmBlockArgumentNames(Block *block,
38                                 OpAsmSetValueNameFn setNameFn) const final {
39     auto op = block->getParentOp();
40     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
41     if (!arrayAttr)
42       return;
43     auto args = block->getArguments();
44     auto e = std::min(arrayAttr.size(), args.size());
45     for (unsigned i = 0; i < e; ++i) {
46       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
47         setNameFn(args[i], strAttr.getValue());
48     }
49   }
50 };
51 
52 struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
53   using OpFolderDialectInterface::OpFolderDialectInterface;
54 
55   /// Registered hook to check if the given region, which is attached to an
56   /// operation that is *not* isolated from above, should be used when
57   /// materializing constants.
58   bool shouldMaterializeInto(Region *region) const final {
59     // If this is a one region operation, then insert into it.
60     return isa<OneRegionOp>(region->getParentOp());
61   }
62 };
63 
64 /// This class defines the interface for handling inlining with standard
65 /// operations.
66 struct TestInlinerInterface : public DialectInlinerInterface {
67   using DialectInlinerInterface::DialectInlinerInterface;
68 
69   //===--------------------------------------------------------------------===//
70   // Analysis Hooks
71   //===--------------------------------------------------------------------===//
72 
73   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
74     // Inlining into test dialect regions is legal.
75     return true;
76   }
77   bool isLegalToInline(Operation *, Region *,
78                        BlockAndValueMapping &) const final {
79     return true;
80   }
81 
82   bool shouldAnalyzeRecursively(Operation *op) const final {
83     // Analyze recursively if this is not a functional region operation, it
84     // froms a separate functional scope.
85     return !isa<FunctionalRegionOp>(op);
86   }
87 
88   //===--------------------------------------------------------------------===//
89   // Transformation Hooks
90   //===--------------------------------------------------------------------===//
91 
92   /// Handle the given inlined terminator by replacing it with a new operation
93   /// as necessary.
94   void handleTerminator(Operation *op,
95                         ArrayRef<Value> valuesToRepl) const final {
96     // Only handle "test.return" here.
97     auto returnOp = dyn_cast<TestReturnOp>(op);
98     if (!returnOp)
99       return;
100 
101     // Replace the values directly with the return operands.
102     assert(returnOp.getNumOperands() == valuesToRepl.size());
103     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
104       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
105   }
106 
107   /// Attempt to materialize a conversion for a type mismatch between a call
108   /// from this dialect, and a callable region. This method should generate an
109   /// operation that takes 'input' as the only operand, and produces a single
110   /// result of 'resultType'. If a conversion can not be generated, nullptr
111   /// should be returned.
112   Operation *materializeCallConversion(OpBuilder &builder, Value input,
113                                        Type resultType,
114                                        Location conversionLoc) const final {
115     // Only allow conversion for i16/i32 types.
116     if (!(resultType.isSignlessInteger(16) ||
117           resultType.isSignlessInteger(32)) ||
118         !(input.getType().isSignlessInteger(16) ||
119           input.getType().isSignlessInteger(32)))
120       return nullptr;
121     return builder.create<TestCastOp>(conversionLoc, resultType, input);
122   }
123 };
124 } // end anonymous namespace
125 
126 //===----------------------------------------------------------------------===//
127 // TestDialect
128 //===----------------------------------------------------------------------===//
129 
/// Constructs the test dialect: registers every op class produced by the
/// GET_OP_LIST section of the generated TestOps.cpp.inc, attaches the
/// AsmPrinter, folder, and inliner dialect interfaces defined above, and calls
/// allowUnknownOperations() so that operations which are not registered are
/// accepted.
TestDialect::TestDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // The op list is generated by TableGen and spliced in as template arguments.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
                TestInlinerInterface>();
  allowUnknownOperations();
}
140 
141 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
142                                                     NamedAttribute namedAttr) {
143   if (namedAttr.first == "test.invalid_attr")
144     return op->emitError() << "invalid to use 'test.invalid_attr'";
145   return success();
146 }
147 
148 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
149                                                     unsigned regionIndex,
150                                                     unsigned argIndex,
151                                                     NamedAttribute namedAttr) {
152   if (namedAttr.first == "test.invalid_attr")
153     return op->emitError() << "invalid to use 'test.invalid_attr'";
154   return success();
155 }
156 
157 LogicalResult
158 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
159                                          unsigned resultIndex,
160                                          NamedAttribute namedAttr) {
161   if (namedAttr.first == "test.invalid_attr")
162     return op->emitError() << "invalid to use 'test.invalid_attr'";
163   return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // TestBranchOp
168 //===----------------------------------------------------------------------===//
169 
170 Optional<MutableOperandRange>
171 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
172   assert(index == 0 && "invalid successor index");
173   return targetOperandsMutable();
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // TestFoldToCallOp
178 //===----------------------------------------------------------------------===//
179 
180 namespace {
181 struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
182   using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
183 
184   LogicalResult matchAndRewrite(FoldToCallOp op,
185                                 PatternRewriter &rewriter) const override {
186     rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
187                                         ValueRange());
188     return success();
189   }
190 };
191 } // end anonymous namespace
192 
193 void FoldToCallOp::getCanonicalizationPatterns(
194     OwningRewritePatternList &results, MLIRContext *context) {
195   results.insert<FoldToCallOpPattern>(context);
196 }
197 
198 //===----------------------------------------------------------------------===//
199 // Test IsolatedRegionOp - parse passthrough region arguments.
200 //===----------------------------------------------------------------------===//
201 
202 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
203                                          OperationState &result) {
204   OpAsmParser::OperandType argInfo;
205   Type argType = parser.getBuilder().getIndexType();
206 
207   // Parse the input operand.
208   if (parser.parseOperand(argInfo) ||
209       parser.resolveOperand(argInfo, argType, result.operands))
210     return failure();
211 
212   // Parse the body region, and reuse the operand info as the argument info.
213   Region *body = result.addRegion();
214   return parser.parseRegion(*body, argInfo, argType,
215                             /*enableNameShadowing=*/true);
216 }
217 
218 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
219   p << "test.isolated_region ";
220   p.printOperand(op.getOperand());
221   p.shadowRegionArgs(op.region(), op.getOperand());
222   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
223 }
224 
225 //===----------------------------------------------------------------------===//
226 // Test AffineScopeOp
227 //===----------------------------------------------------------------------===//
228 
229 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
230                                       OperationState &result) {
231   // Parse the body region, and reuse the operand info as the argument info.
232   Region *body = result.addRegion();
233   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
234 }
235 
236 static void print(OpAsmPrinter &p, AffineScopeOp op) {
237   p << "test.affine_scope ";
238   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // Test parser.
243 //===----------------------------------------------------------------------===//
244 
245 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
246                                          OperationState &result) {
247   StringRef keyword;
248   if (parser.parseKeyword(&keyword))
249     return failure();
250   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
251   return success();
252 }
253 
254 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
255   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
256 }
257 
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
260 
261 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
262                                          OperationState &result) {
263   if (parser.parseKeyword("wraps"))
264     return failure();
265 
266   // Parse the wrapped op in a region
267   Region &body = *result.addRegion();
268   body.push_back(new Block);
269   Block &block = body.back();
270   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
271   if (!wrapped_op)
272     return failure();
273 
274   // Create a return terminator in the inner region, pass as operand to the
275   // terminator the returned values from the wrapped operation.
276   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
277   OpBuilder builder(parser.getBuilder().getContext());
278   builder.setInsertionPointToEnd(&block);
279   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
280 
281   // Get the results type for the wrapping op from the terminator operands.
282   Operation &return_op = body.back().back();
283   result.types.append(return_op.operand_type_begin(),
284                       return_op.operand_type_end());
285 
286   // Use the location of the wrapped op for the "test.wrapping_region" op.
287   result.location = wrapped_op->getLoc();
288 
289   return success();
290 }
291 
292 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
293   p << op.getOperationName() << " wraps ";
294   p.printGenericOp(&op.region().front().front());
295 }
296 
297 //===----------------------------------------------------------------------===//
298 // Test PolyForOp - parse list of region arguments.
299 //===----------------------------------------------------------------------===//
300 
301 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
302   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
303   // Parse list of region arguments without a delimiter.
304   if (parser.parseRegionArgumentList(ivsInfo))
305     return failure();
306 
307   // Parse the body region.
308   Region *body = result.addRegion();
309   auto &builder = parser.getBuilder();
310   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
311   return parser.parseRegion(*body, ivsInfo, argTypes);
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Test removing op with inner ops.
316 //===----------------------------------------------------------------------===//
317 
318 namespace {
319 struct TestRemoveOpWithInnerOps
320     : public OpRewritePattern<TestOpWithRegionPattern> {
321   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
322 
323   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
324                                 PatternRewriter &rewriter) const override {
325     rewriter.eraseOp(op);
326     return success();
327   }
328 };
329 } // end anonymous namespace
330 
331 void TestOpWithRegionPattern::getCanonicalizationPatterns(
332     OwningRewritePatternList &results, MLIRContext *context) {
333   results.insert<TestRemoveOpWithInnerOps>(context);
334 }
335 
336 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
337   return operand();
338 }
339 
340 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
341     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
342   for (Value input : this->operands()) {
343     results.push_back(input);
344   }
345   return success();
346 }
347 
348 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
349   assert(operands.size() == 1);
350   if (operands.front()) {
351     setAttr("attr", operands.front());
352     return getResult();
353   }
354   return {};
355 }
356 
357 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
358     MLIRContext *, Optional<Location> location, ValueRange operands,
359     DictionaryAttr attributes, RegionRange regions,
360     SmallVectorImpl<Type> &inferredReturnTypes) {
361   if (operands[0].getType() != operands[1].getType()) {
362     return emitOptionalError(location, "operand type mismatch ",
363                              operands[0].getType(), " vs ",
364                              operands[1].getType());
365   }
366   inferredReturnTypes.assign({operands[0].getType()});
367   return success();
368 }
369 
370 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
371     MLIRContext *context, Optional<Location> location, ValueRange operands,
372     DictionaryAttr attributes, RegionRange regions,
373     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
374   // Create return type consisting of the last element of the first operand.
375   auto operandType = *operands.getTypes().begin();
376   auto sval = operandType.dyn_cast<ShapedType>();
377   if (!sval) {
378     return emitOptionalError(location, "only shaped type operands allowed");
379   }
380   int64_t dim =
381       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
382   auto type = IntegerType::get(17, context);
383   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
384   return success();
385 }
386 
387 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
388     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
389   shapes = SmallVector<Value, 1>{
390       builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)};
391   return success();
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // Test SideEffect interfaces
396 //===----------------------------------------------------------------------===//
397 
398 namespace {
399 /// A test resource for side effects.
400 struct TestResource : public SideEffects::Resource::Base<TestResource> {
401   StringRef getName() final { return "<Test>"; }
402 };
403 } // end anonymous namespace
404 
405 void SideEffectOp::getEffects(
406     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
407   // Check for an effects attribute on the op instance.
408   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
409   if (!effectsAttr)
410     return;
411 
412   // If there is one, it is an array of dictionary attributes that hold
413   // information on the effects of this operation.
414   for (Attribute element : effectsAttr) {
415     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
416 
417     // Get the specific memory effect.
418     MemoryEffects::Effect *effect =
419         llvm::StringSwitch<MemoryEffects::Effect *>(
420             effectElement.get("effect").cast<StringAttr>().getValue())
421             .Case("allocate", MemoryEffects::Allocate::get())
422             .Case("free", MemoryEffects::Free::get())
423             .Case("read", MemoryEffects::Read::get())
424             .Case("write", MemoryEffects::Write::get());
425 
426     // Check for a result to affect.
427     Value value;
428     if (effectElement.get("on_result"))
429       value = getResult();
430 
431     // Check for a non-default resource to use.
432     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
433     if (effectElement.get("test_resource"))
434       resource = TestResource::get();
435 
436     effects.emplace_back(effect, value, resource);
437   }
438 }
439 
440 //===----------------------------------------------------------------------===//
441 // StringAttrPrettyNameOp
442 //===----------------------------------------------------------------------===//
443 
444 // This op has fancy handling of its SSA result name.
445 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
446                                                OperationState &result) {
447   // Add the result types.
448   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
449     result.addTypes(parser.getBuilder().getIntegerType(32));
450 
451   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
452     return failure();
453 
454   // If the attribute dictionary contains no 'names' attribute, infer it from
455   // the SSA name (if specified).
456   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
457     return attr.first == "names";
458   });
459 
460   // If there was no name specified, check to see if there was a useful name
461   // specified in the asm file.
462   if (hadNames || parser.getNumResults() == 0)
463     return success();
464 
465   SmallVector<StringRef, 4> names;
466   auto *context = result.getContext();
467 
468   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
469     auto resultName = parser.getResultName(i);
470     StringRef nameStr;
471     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
472       nameStr = resultName.first;
473 
474     names.push_back(nameStr);
475   }
476 
477   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
478   result.attributes.push_back({Identifier::get("names", context), namesAttr});
479   return success();
480 }
481 
482 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
483   p << "test.string_attr_pretty_name";
484 
485   // Note that we only need to print the "name" attribute if the asmprinter
486   // result name disagrees with it.  This can happen in strange cases, e.g.
487   // when there are conflicts.
488   bool namesDisagree = op.names().size() != op.getNumResults();
489 
490   SmallString<32> resultNameStr;
491   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
492     resultNameStr.clear();
493     llvm::raw_svector_ostream tmpStream(resultNameStr);
494     p.printOperand(op.getResult(i), tmpStream);
495 
496     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
497     if (!expectedName ||
498         tmpStream.str().drop_front() != expectedName.getValue()) {
499       namesDisagree = true;
500     }
501   }
502 
503   if (namesDisagree)
504     p.printOptionalAttrDictWithKeyword(op.getAttrs());
505   else
506     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
507 }
508 
509 // We set the SSA name in the asm syntax to the contents of the name
510 // attribute.
511 void StringAttrPrettyNameOp::getAsmResultNames(
512     function_ref<void(Value, StringRef)> setNameFn) {
513 
514   auto value = names();
515   for (size_t i = 0, e = value.size(); i != e; ++i)
516     if (auto str = value[i].dyn_cast<StringAttr>())
517       if (!str.getValue().empty())
518         setNameFn(getResult(i), str.getValue());
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // Dialect Registration
523 //===----------------------------------------------------------------------===//
524 
525 // Static initialization for Test dialect registration.
526 static mlir::DialectRegistration<mlir::TestDialect> testDialect;
527 
528 #include "TestOpEnums.cpp.inc"
529 #include "TestOpStructs.cpp.inc"
530 
531 #define GET_OP_CLASSES
532 #include "TestOps.cpp.inc"
533