1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Function.h"
12 #include "mlir/IR/Module.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/TypeUtilities.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/Transforms/InliningUtils.h"
17 #include "llvm/ADT/StringSwitch.h"
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // TestDialect Interfaces
23 //===----------------------------------------------------------------------===//
24 
25 namespace {
26 
27 // Test support for interacting with the AsmPrinter.
28 struct TestOpAsmInterface : public OpAsmDialectInterface {
29   using OpAsmDialectInterface::OpAsmDialectInterface;
30 
31   void getAsmResultNames(Operation *op,
32                          OpAsmSetValueNameFn setNameFn) const final {
33     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
34       setNameFn(asmOp, "result");
35   }
36 
37   void getAsmBlockArgumentNames(Block *block,
38                                 OpAsmSetValueNameFn setNameFn) const final {
39     auto op = block->getParentOp();
40     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
41     if (!arrayAttr)
42       return;
43     auto args = block->getArguments();
44     auto e = std::min(arrayAttr.size(), args.size());
45     for (unsigned i = 0; i < e; ++i) {
46       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
47         setNameFn(args[i], strAttr.getValue());
48     }
49   }
50 };
51 
52 struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
53   using OpFolderDialectInterface::OpFolderDialectInterface;
54 
55   /// Registered hook to check if the given region, which is attached to an
56   /// operation that is *not* isolated from above, should be used when
57   /// materializing constants.
58   bool shouldMaterializeInto(Region *region) const final {
59     // If this is a one region operation, then insert into it.
60     return isa<OneRegionOp>(region->getParentOp());
61   }
62 };
63 
64 /// This class defines the interface for handling inlining with standard
65 /// operations.
66 struct TestInlinerInterface : public DialectInlinerInterface {
67   using DialectInlinerInterface::DialectInlinerInterface;
68 
69   //===--------------------------------------------------------------------===//
70   // Analysis Hooks
71   //===--------------------------------------------------------------------===//
72 
73   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
74     // Inlining into test dialect regions is legal.
75     return true;
76   }
77   bool isLegalToInline(Operation *, Region *,
78                        BlockAndValueMapping &) const final {
79     return true;
80   }
81 
82   bool shouldAnalyzeRecursively(Operation *op) const final {
83     // Analyze recursively if this is not a functional region operation, it
84     // froms a separate functional scope.
85     return !isa<FunctionalRegionOp>(op);
86   }
87 
88   //===--------------------------------------------------------------------===//
89   // Transformation Hooks
90   //===--------------------------------------------------------------------===//
91 
92   /// Handle the given inlined terminator by replacing it with a new operation
93   /// as necessary.
94   void handleTerminator(Operation *op,
95                         ArrayRef<Value> valuesToRepl) const final {
96     // Only handle "test.return" here.
97     auto returnOp = dyn_cast<TestReturnOp>(op);
98     if (!returnOp)
99       return;
100 
101     // Replace the values directly with the return operands.
102     assert(returnOp.getNumOperands() == valuesToRepl.size());
103     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
104       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
105   }
106 
107   /// Attempt to materialize a conversion for a type mismatch between a call
108   /// from this dialect, and a callable region. This method should generate an
109   /// operation that takes 'input' as the only operand, and produces a single
110   /// result of 'resultType'. If a conversion can not be generated, nullptr
111   /// should be returned.
112   Operation *materializeCallConversion(OpBuilder &builder, Value input,
113                                        Type resultType,
114                                        Location conversionLoc) const final {
115     // Only allow conversion for i16/i32 types.
116     if (!(resultType.isSignlessInteger(16) ||
117           resultType.isSignlessInteger(32)) ||
118         !(input.getType().isSignlessInteger(16) ||
119           input.getType().isSignlessInteger(32)))
120       return nullptr;
121     return builder.create<TestCastOp>(conversionLoc, resultType, input);
122   }
123 };
124 } // end anonymous namespace
125 
126 //===----------------------------------------------------------------------===//
127 // TestDialect
128 //===----------------------------------------------------------------------===//
129 
130 TestDialect::TestDialect(MLIRContext *context)
131     : Dialect(getDialectNamespace(), context) {
132   addOperations<
133 #define GET_OP_LIST
134 #include "TestOps.cpp.inc"
135       >();
136   addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
137                 TestInlinerInterface>();
138   allowUnknownOperations();
139 }
140 
141 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
142                                                     NamedAttribute namedAttr) {
143   if (namedAttr.first == "test.invalid_attr")
144     return op->emitError() << "invalid to use 'test.invalid_attr'";
145   return success();
146 }
147 
148 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
149                                                     unsigned regionIndex,
150                                                     unsigned argIndex,
151                                                     NamedAttribute namedAttr) {
152   if (namedAttr.first == "test.invalid_attr")
153     return op->emitError() << "invalid to use 'test.invalid_attr'";
154   return success();
155 }
156 
157 LogicalResult
158 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
159                                          unsigned resultIndex,
160                                          NamedAttribute namedAttr) {
161   if (namedAttr.first == "test.invalid_attr")
162     return op->emitError() << "invalid to use 'test.invalid_attr'";
163   return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // TestBranchOp
168 //===----------------------------------------------------------------------===//
169 
170 Optional<OperandRange> TestBranchOp::getSuccessorOperands(unsigned index) {
171   assert(index == 0 && "invalid successor index");
172   return getOperands();
173 }
174 
175 bool TestBranchOp::canEraseSuccessorOperand() { return true; }
176 
177 //===----------------------------------------------------------------------===//
178 // Test IsolatedRegionOp - parse passthrough region arguments.
179 //===----------------------------------------------------------------------===//
180 
181 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
182                                          OperationState &result) {
183   OpAsmParser::OperandType argInfo;
184   Type argType = parser.getBuilder().getIndexType();
185 
186   // Parse the input operand.
187   if (parser.parseOperand(argInfo) ||
188       parser.resolveOperand(argInfo, argType, result.operands))
189     return failure();
190 
191   // Parse the body region, and reuse the operand info as the argument info.
192   Region *body = result.addRegion();
193   return parser.parseRegion(*body, argInfo, argType,
194                             /*enableNameShadowing=*/true);
195 }
196 
197 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
198   p << "test.isolated_region ";
199   p.printOperand(op.getOperand());
200   p.shadowRegionArgs(op.region(), op.getOperand());
201   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // Test PolyhedralScopeOp
206 //===----------------------------------------------------------------------===//
207 
208 static ParseResult parsePolyhedralScopeOp(OpAsmParser &parser,
209                                           OperationState &result) {
210   // Parse the body region, and reuse the operand info as the argument info.
211   Region *body = result.addRegion();
212   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
213 }
214 
215 static void print(OpAsmPrinter &p, PolyhedralScopeOp op) {
216   p << "test.polyhedral_scope ";
217   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
218 }
219 
220 //===----------------------------------------------------------------------===//
221 // Test parser.
222 //===----------------------------------------------------------------------===//
223 
224 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
225                                          OperationState &result) {
226   StringRef keyword;
227   if (parser.parseKeyword(&keyword))
228     return failure();
229   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
230   return success();
231 }
232 
233 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
234   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
235 }
236 
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
239 
240 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
241                                          OperationState &result) {
242   if (parser.parseKeyword("wraps"))
243     return failure();
244 
245   // Parse the wrapped op in a region
246   Region &body = *result.addRegion();
247   body.push_back(new Block);
248   Block &block = body.back();
249   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
250   if (!wrapped_op)
251     return failure();
252 
253   // Create a return terminator in the inner region, pass as operand to the
254   // terminator the returned values from the wrapped operation.
255   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
256   OpBuilder builder(parser.getBuilder().getContext());
257   builder.setInsertionPointToEnd(&block);
258   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
259 
260   // Get the results type for the wrapping op from the terminator operands.
261   Operation &return_op = body.back().back();
262   result.types.append(return_op.operand_type_begin(),
263                       return_op.operand_type_end());
264 
265   // Use the location of the wrapped op for the "test.wrapping_region" op.
266   result.location = wrapped_op->getLoc();
267 
268   return success();
269 }
270 
271 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
272   p << op.getOperationName() << " wraps ";
273   p.printGenericOp(&op.region().front().front());
274 }
275 
276 //===----------------------------------------------------------------------===//
277 // Test PolyForOp - parse list of region arguments.
278 //===----------------------------------------------------------------------===//
279 
280 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
281   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
282   // Parse list of region arguments without a delimiter.
283   if (parser.parseRegionArgumentList(ivsInfo))
284     return failure();
285 
286   // Parse the body region.
287   Region *body = result.addRegion();
288   auto &builder = parser.getBuilder();
289   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
290   return parser.parseRegion(*body, ivsInfo, argTypes);
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // Test removing op with inner ops.
295 //===----------------------------------------------------------------------===//
296 
297 namespace {
298 struct TestRemoveOpWithInnerOps
299     : public OpRewritePattern<TestOpWithRegionPattern> {
300   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
301 
302   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
303                                 PatternRewriter &rewriter) const override {
304     rewriter.eraseOp(op);
305     return success();
306   }
307 };
308 } // end anonymous namespace
309 
310 void TestOpWithRegionPattern::getCanonicalizationPatterns(
311     OwningRewritePatternList &results, MLIRContext *context) {
312   results.insert<TestRemoveOpWithInnerOps>(context);
313 }
314 
315 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
316   return operand();
317 }
318 
319 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
320     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
321   for (Value input : this->operands()) {
322     results.push_back(input);
323   }
324   return success();
325 }
326 
327 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
328     MLIRContext *, Optional<Location> location, ValueRange operands,
329     ArrayRef<NamedAttribute> attributes, RegionRange regions,
330     SmallVectorImpl<Type> &inferredReturnTypes) {
331   if (operands[0].getType() != operands[1].getType()) {
332     return emitOptionalError(location, "operand type mismatch ",
333                              operands[0].getType(), " vs ",
334                              operands[1].getType());
335   }
336   inferredReturnTypes.assign({operands[0].getType()});
337   return success();
338 }
339 
340 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
341     MLIRContext *context, Optional<Location> location, ValueRange operands,
342     ArrayRef<NamedAttribute> attributes, RegionRange regions,
343     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
344   // Create return type consisting of the last element of the first operand.
345   auto operandType = *operands.getTypes().begin();
346   auto sval = operandType.dyn_cast<ShapedType>();
347   if (!sval) {
348     return emitOptionalError(location, "only shaped type operands allowed");
349   }
350   int64_t dim =
351       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
352   auto type = IntegerType::get(17, context);
353   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
354   return success();
355 }
356 
357 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
358     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
359   shapes = SmallVector<Value, 1>{
360       builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)};
361   return success();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // Test SideEffect interfaces
366 //===----------------------------------------------------------------------===//
367 
368 namespace {
369 /// A test resource for side effects.
370 struct TestResource : public SideEffects::Resource::Base<TestResource> {
371   StringRef getName() final { return "<Test>"; }
372 };
373 } // end anonymous namespace
374 
375 void SideEffectOp::getEffects(
376     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
377   // Check for an effects attribute on the op instance.
378   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
379   if (!effectsAttr)
380     return;
381 
382   // If there is one, it is an array of dictionary attributes that hold
383   // information on the effects of this operation.
384   for (Attribute element : effectsAttr) {
385     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
386 
387     // Get the specific memory effect.
388     MemoryEffects::Effect *effect =
389         llvm::StringSwitch<MemoryEffects::Effect *>(
390             effectElement.get("effect").cast<StringAttr>().getValue())
391             .Case("allocate", MemoryEffects::Allocate::get())
392             .Case("free", MemoryEffects::Free::get())
393             .Case("read", MemoryEffects::Read::get())
394             .Case("write", MemoryEffects::Write::get());
395 
396     // Check for a result to affect.
397     Value value;
398     if (effectElement.get("on_result"))
399       value = getResult();
400 
401     // Check for a non-default resource to use.
402     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
403     if (effectElement.get("test_resource"))
404       resource = TestResource::get();
405 
406     effects.emplace_back(effect, value, resource);
407   }
408 }
409 
410 //===----------------------------------------------------------------------===//
411 // StringAttrPrettyNameOp
412 //===----------------------------------------------------------------------===//
413 
414 // This op has fancy handling of its SSA result name.
415 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
416                                                OperationState &result) {
417   // Add the result types.
418   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
419     result.addTypes(parser.getBuilder().getIntegerType(32));
420 
421   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
422     return failure();
423 
424   // If the attribute dictionary contains no 'names' attribute, infer it from
425   // the SSA name (if specified).
426   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
427     return attr.first == "names";
428   });
429 
430   // If there was no name specified, check to see if there was a useful name
431   // specified in the asm file.
432   if (hadNames || parser.getNumResults() == 0)
433     return success();
434 
435   SmallVector<StringRef, 4> names;
436   auto *context = result.getContext();
437 
438   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
439     auto resultName = parser.getResultName(i);
440     StringRef nameStr;
441     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
442       nameStr = resultName.first;
443 
444     names.push_back(nameStr);
445   }
446 
447   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
448   result.attributes.push_back({Identifier::get("names", context), namesAttr});
449   return success();
450 }
451 
452 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
453   p << "test.string_attr_pretty_name";
454 
455   // Note that we only need to print the "name" attribute if the asmprinter
456   // result name disagrees with it.  This can happen in strange cases, e.g.
457   // when there are conflicts.
458   bool namesDisagree = op.names().size() != op.getNumResults();
459 
460   SmallString<32> resultNameStr;
461   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
462     resultNameStr.clear();
463     llvm::raw_svector_ostream tmpStream(resultNameStr);
464     p.printOperand(op.getResult(i), tmpStream);
465 
466     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
467     if (!expectedName ||
468         tmpStream.str().drop_front() != expectedName.getValue()) {
469       namesDisagree = true;
470     }
471   }
472 
473   if (namesDisagree)
474     p.printOptionalAttrDictWithKeyword(op.getAttrs());
475   else
476     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
477 }
478 
479 // We set the SSA name in the asm syntax to the contents of the name
480 // attribute.
481 void StringAttrPrettyNameOp::getAsmResultNames(
482     function_ref<void(Value, StringRef)> setNameFn) {
483 
484   auto value = names();
485   for (size_t i = 0, e = value.size(); i != e; ++i)
486     if (auto str = value[i].dyn_cast<StringAttr>())
487       if (!str.getValue().empty())
488         setNameFn(getResult(i), str.getValue());
489 }
490 
491 //===----------------------------------------------------------------------===//
492 // Dialect Registration
493 //===----------------------------------------------------------------------===//
494 
495 // Static initialization for Test dialect registration.
496 static mlir::DialectRegistration<mlir::TestDialect> testDialect;
497 
498 #include "TestOpEnums.cpp.inc"
499 #include "TestOpStructs.cpp.inc"
500 
501 #define GET_OP_CLASSES
502 #include "TestOps.cpp.inc"
503