1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Function.h"
12 #include "mlir/IR/Module.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/TypeUtilities.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/Transforms/InliningUtils.h"
17 #include "llvm/ADT/StringSwitch.h"
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // TestDialect Interfaces
23 //===----------------------------------------------------------------------===//
24 
25 namespace {
26 
27 // Test support for interacting with the AsmPrinter.
28 struct TestOpAsmInterface : public OpAsmDialectInterface {
29   using OpAsmDialectInterface::OpAsmDialectInterface;
30 
31   void getAsmResultNames(Operation *op,
32                          OpAsmSetValueNameFn setNameFn) const final {
33     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
34       setNameFn(asmOp, "result");
35   }
36 
37   void getAsmBlockArgumentNames(Block *block,
38                                 OpAsmSetValueNameFn setNameFn) const final {
39     auto op = block->getParentOp();
40     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
41     if (!arrayAttr)
42       return;
43     auto args = block->getArguments();
44     auto e = std::min(arrayAttr.size(), args.size());
45     for (unsigned i = 0; i < e; ++i) {
46       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
47         setNameFn(args[i], strAttr.getValue());
48     }
49   }
50 };
51 
52 struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
53   using OpFolderDialectInterface::OpFolderDialectInterface;
54 
55   /// Registered hook to check if the given region, which is attached to an
56   /// operation that is *not* isolated from above, should be used when
57   /// materializing constants.
58   bool shouldMaterializeInto(Region *region) const final {
59     // If this is a one region operation, then insert into it.
60     return isa<OneRegionOp>(region->getParentOp());
61   }
62 };
63 
64 /// This class defines the interface for handling inlining with standard
65 /// operations.
66 struct TestInlinerInterface : public DialectInlinerInterface {
67   using DialectInlinerInterface::DialectInlinerInterface;
68 
69   //===--------------------------------------------------------------------===//
70   // Analysis Hooks
71   //===--------------------------------------------------------------------===//
72 
73   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
74     // Inlining into test dialect regions is legal.
75     return true;
76   }
77   bool isLegalToInline(Operation *, Region *,
78                        BlockAndValueMapping &) const final {
79     return true;
80   }
81 
82   bool shouldAnalyzeRecursively(Operation *op) const final {
83     // Analyze recursively if this is not a functional region operation, it
84     // froms a separate functional scope.
85     return !isa<FunctionalRegionOp>(op);
86   }
87 
88   //===--------------------------------------------------------------------===//
89   // Transformation Hooks
90   //===--------------------------------------------------------------------===//
91 
92   /// Handle the given inlined terminator by replacing it with a new operation
93   /// as necessary.
94   void handleTerminator(Operation *op,
95                         ArrayRef<Value> valuesToRepl) const final {
96     // Only handle "test.return" here.
97     auto returnOp = dyn_cast<TestReturnOp>(op);
98     if (!returnOp)
99       return;
100 
101     // Replace the values directly with the return operands.
102     assert(returnOp.getNumOperands() == valuesToRepl.size());
103     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
104       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
105   }
106 
107   /// Attempt to materialize a conversion for a type mismatch between a call
108   /// from this dialect, and a callable region. This method should generate an
109   /// operation that takes 'input' as the only operand, and produces a single
110   /// result of 'resultType'. If a conversion can not be generated, nullptr
111   /// should be returned.
112   Operation *materializeCallConversion(OpBuilder &builder, Value input,
113                                        Type resultType,
114                                        Location conversionLoc) const final {
115     // Only allow conversion for i16/i32 types.
116     if (!(resultType.isSignlessInteger(16) ||
117           resultType.isSignlessInteger(32)) ||
118         !(input.getType().isSignlessInteger(16) ||
119           input.getType().isSignlessInteger(32)))
120       return nullptr;
121     return builder.create<TestCastOp>(conversionLoc, resultType, input);
122   }
123 };
124 } // end anonymous namespace
125 
126 //===----------------------------------------------------------------------===//
127 // TestDialect
128 //===----------------------------------------------------------------------===//
129 
/// Construct the test dialect. This registers every op listed in the
/// GET_OP_LIST section of TestOps.cpp.inc, attaches the asm, folder, and
/// inliner dialect interfaces, and calls allowUnknownOperations() so that
/// unregistered operations are accepted for this dialect.
TestDialect::TestDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  // The generated GET_OP_LIST section is expanded in place to form the
  // template argument list of addOperations.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
                TestInlinerInterface>();
  allowUnknownOperations();
}
140 
141 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
142                                                     NamedAttribute namedAttr) {
143   if (namedAttr.first == "test.invalid_attr")
144     return op->emitError() << "invalid to use 'test.invalid_attr'";
145   return success();
146 }
147 
148 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
149                                                     unsigned regionIndex,
150                                                     unsigned argIndex,
151                                                     NamedAttribute namedAttr) {
152   if (namedAttr.first == "test.invalid_attr")
153     return op->emitError() << "invalid to use 'test.invalid_attr'";
154   return success();
155 }
156 
157 LogicalResult
158 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
159                                          unsigned resultIndex,
160                                          NamedAttribute namedAttr) {
161   if (namedAttr.first == "test.invalid_attr")
162     return op->emitError() << "invalid to use 'test.invalid_attr'";
163   return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // TestBranchOp
168 //===----------------------------------------------------------------------===//
169 
170 Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
171   assert(index == 0 && "invalid successor index");
172   return getOperands();
173 }
174 
175 bool TestBranchOp::canEraseSuccessorOperand() { return true; }
176 
177 //===----------------------------------------------------------------------===//
178 // Test IsolatedRegionOp - parse passthrough region arguments.
179 //===----------------------------------------------------------------------===//
180 
181 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
182                                          OperationState &result) {
183   OpAsmParser::OperandType argInfo;
184   Type argType = parser.getBuilder().getIndexType();
185 
186   // Parse the input operand.
187   if (parser.parseOperand(argInfo) ||
188       parser.resolveOperand(argInfo, argType, result.operands))
189     return failure();
190 
191   // Parse the body region, and reuse the operand info as the argument info.
192   Region *body = result.addRegion();
193   return parser.parseRegion(*body, argInfo, argType,
194                             /*enableNameShadowing=*/true);
195 }
196 
197 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
198   p << "test.isolated_region ";
199   p.printOperand(op.getOperand());
200   p.shadowRegionArgs(op.region(), op.getOperand());
201   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
202 }
203 
204 //===----------------------------------------------------------------------===//
// Test WrappedKeywordOp - parse a bare keyword into an attribute.
206 //===----------------------------------------------------------------------===//
207 
208 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
209                                          OperationState &result) {
210   StringRef keyword;
211   if (parser.parseKeyword(&keyword))
212     return failure();
213   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
214   return success();
215 }
216 
217 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
218   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
219 }
220 
221 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
223 
224 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
225                                          OperationState &result) {
226   if (parser.parseKeyword("wraps"))
227     return failure();
228 
229   // Parse the wrapped op in a region
230   Region &body = *result.addRegion();
231   body.push_back(new Block);
232   Block &block = body.back();
233   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
234   if (!wrapped_op)
235     return failure();
236 
237   // Create a return terminator in the inner region, pass as operand to the
238   // terminator the returned values from the wrapped operation.
239   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
240   OpBuilder builder(parser.getBuilder().getContext());
241   builder.setInsertionPointToEnd(&block);
242   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
243 
244   // Get the results type for the wrapping op from the terminator operands.
245   Operation &return_op = body.back().back();
246   result.types.append(return_op.operand_type_begin(),
247                       return_op.operand_type_end());
248 
249   // Use the location of the wrapped op for the "test.wrapping_region" op.
250   result.location = wrapped_op->getLoc();
251 
252   return success();
253 }
254 
255 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
256   p << op.getOperationName() << " wraps ";
257   p.printGenericOp(&op.region().front().front());
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // Test PolyForOp - parse list of region arguments.
262 //===----------------------------------------------------------------------===//
263 
264 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
265   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
266   // Parse list of region arguments without a delimiter.
267   if (parser.parseRegionArgumentList(ivsInfo))
268     return failure();
269 
270   // Parse the body region.
271   Region *body = result.addRegion();
272   auto &builder = parser.getBuilder();
273   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
274   return parser.parseRegion(*body, ivsInfo, argTypes);
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // Test removing op with inner ops.
279 //===----------------------------------------------------------------------===//
280 
281 namespace {
282 struct TestRemoveOpWithInnerOps
283     : public OpRewritePattern<TestOpWithRegionPattern> {
284   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
285 
286   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
287                                 PatternRewriter &rewriter) const override {
288     rewriter.eraseOp(op);
289     return success();
290   }
291 };
292 } // end anonymous namespace
293 
294 void TestOpWithRegionPattern::getCanonicalizationPatterns(
295     OwningRewritePatternList &results, MLIRContext *context) {
296   results.insert<TestRemoveOpWithInnerOps>(context);
297 }
298 
299 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
300   return operand();
301 }
302 
303 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
304     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
305   for (Value input : this->operands()) {
306     results.push_back(input);
307   }
308   return success();
309 }
310 
311 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
312     MLIRContext *, Optional<Location> location, ValueRange operands,
313     ArrayRef<NamedAttribute> attributes, RegionRange regions,
314     SmallVectorImpl<Type> &inferredReturnTypes) {
315   if (operands[0].getType() != operands[1].getType()) {
316     return emitOptionalError(location, "operand type mismatch ",
317                              operands[0].getType(), " vs ",
318                              operands[1].getType());
319   }
320   inferredReturnTypes.assign({operands[0].getType()});
321   return success();
322 }
323 
324 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
325     MLIRContext *context, Optional<Location> location, ValueRange operands,
326     ArrayRef<NamedAttribute> attributes, RegionRange regions,
327     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
328   // Create return type consisting of the last element of the first operand.
329   auto operandType = *operands.getTypes().begin();
330   auto sval = operandType.dyn_cast<ShapedType>();
331   if (!sval) {
332     return emitOptionalError(location, "only shaped type operands allowed");
333   }
334   int64_t dim =
335       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
336   auto type = IntegerType::get(17, context);
337   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
338   return success();
339 }
340 
341 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
342     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
343   shapes = SmallVector<Value, 1>{
344       builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)};
345   return success();
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // Test SideEffect interfaces
350 //===----------------------------------------------------------------------===//
351 
352 namespace {
353 /// A test resource for side effects.
354 struct TestResource : public SideEffects::Resource::Base<TestResource> {
355   StringRef getName() final { return "<Test>"; }
356 };
357 } // end anonymous namespace
358 
359 void SideEffectOp::getEffects(
360     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
361   // Check for an effects attribute on the op instance.
362   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
363   if (!effectsAttr)
364     return;
365 
366   // If there is one, it is an array of dictionary attributes that hold
367   // information on the effects of this operation.
368   for (Attribute element : effectsAttr) {
369     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
370 
371     // Get the specific memory effect.
372     MemoryEffects::Effect *effect =
373         llvm::StringSwitch<MemoryEffects::Effect *>(
374             effectElement.get("effect").cast<StringAttr>().getValue())
375             .Case("allocate", MemoryEffects::Allocate::get())
376             .Case("free", MemoryEffects::Free::get())
377             .Case("read", MemoryEffects::Read::get())
378             .Case("write", MemoryEffects::Write::get());
379 
380     // Check for a result to affect.
381     Value value;
382     if (effectElement.get("on_result"))
383       value = getResult();
384 
385     // Check for a non-default resource to use.
386     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
387     if (effectElement.get("test_resource"))
388       resource = TestResource::get();
389 
390     effects.emplace_back(effect, value, resource);
391   }
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // StringAttrPrettyNameOp
396 //===----------------------------------------------------------------------===//
397 
398 // This op has fancy handling of its SSA result name.
399 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
400                                                OperationState &result) {
401   // Add the result types.
402   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
403     result.addTypes(parser.getBuilder().getIntegerType(32));
404 
405   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
406     return failure();
407 
408   // If the attribute dictionary contains no 'names' attribute, infer it from
409   // the SSA name (if specified).
410   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
411     return attr.first == "names";
412   });
413 
414   // If there was no name specified, check to see if there was a useful name
415   // specified in the asm file.
416   if (hadNames || parser.getNumResults() == 0)
417     return success();
418 
419   SmallVector<StringRef, 4> names;
420   auto *context = result.getContext();
421 
422   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
423     auto resultName = parser.getResultName(i);
424     StringRef nameStr;
425     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
426       nameStr = resultName.first;
427 
428     names.push_back(nameStr);
429   }
430 
431   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
432   result.attributes.push_back({Identifier::get("names", context), namesAttr});
433   return success();
434 }
435 
436 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
437   p << "test.string_attr_pretty_name";
438 
439   // Note that we only need to print the "name" attribute if the asmprinter
440   // result name disagrees with it.  This can happen in strange cases, e.g.
441   // when there are conflicts.
442   bool namesDisagree = op.names().size() != op.getNumResults();
443 
444   SmallString<32> resultNameStr;
445   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
446     resultNameStr.clear();
447     llvm::raw_svector_ostream tmpStream(resultNameStr);
448     p.printOperand(op.getResult(i), tmpStream);
449 
450     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
451     if (!expectedName ||
452         tmpStream.str().drop_front() != expectedName.getValue()) {
453       namesDisagree = true;
454     }
455   }
456 
457   if (namesDisagree)
458     p.printOptionalAttrDictWithKeyword(op.getAttrs());
459   else
460     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
461 }
462 
463 // We set the SSA name in the asm syntax to the contents of the name
464 // attribute.
465 void StringAttrPrettyNameOp::getAsmResultNames(
466     function_ref<void(Value, StringRef)> setNameFn) {
467 
468   auto value = names();
469   for (size_t i = 0, e = value.size(); i != e; ++i)
470     if (auto str = value[i].dyn_cast<StringAttr>())
471       if (!str.getValue().empty())
472         setNameFn(getResult(i), str.getValue());
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // Dialect Registration
477 //===----------------------------------------------------------------------===//
478 
// Static initialization for Test dialect registration: this global
// mlir::DialectRegistration object is constructed during static
// initialization and registers mlir::TestDialect.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
481 
482 #include "TestOpEnums.cpp.inc"
483 
484 #define GET_OP_CLASSES
485 #include "TestOps.cpp.inc"
486