1 //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Function.h"
12 #include "mlir/IR/Module.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/TypeUtilities.h"
15 #include "mlir/Transforms/FoldUtils.h"
16 #include "mlir/Transforms/InliningUtils.h"
17 #include "llvm/ADT/StringSwitch.h"
18 
19 using namespace mlir;
20 
21 //===----------------------------------------------------------------------===//
22 // TestDialect Interfaces
23 //===----------------------------------------------------------------------===//
24 
25 namespace {
26 
27 // Test support for interacting with the AsmPrinter.
28 struct TestOpAsmInterface : public OpAsmDialectInterface {
29   using OpAsmDialectInterface::OpAsmDialectInterface;
30 
31   void getAsmResultNames(Operation *op,
32                          OpAsmSetValueNameFn setNameFn) const final {
33     if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
34       setNameFn(asmOp, "result");
35   }
36 
37   void getAsmBlockArgumentNames(Block *block,
38                                 OpAsmSetValueNameFn setNameFn) const final {
39     auto op = block->getParentOp();
40     auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
41     if (!arrayAttr)
42       return;
43     auto args = block->getArguments();
44     auto e = std::min(arrayAttr.size(), args.size());
45     for (unsigned i = 0; i < e; ++i) {
46       if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
47         setNameFn(args[i], strAttr.getValue());
48     }
49   }
50 };
51 
52 struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
53   using OpFolderDialectInterface::OpFolderDialectInterface;
54 
55   /// Registered hook to check if the given region, which is attached to an
56   /// operation that is *not* isolated from above, should be used when
57   /// materializing constants.
58   bool shouldMaterializeInto(Region *region) const final {
59     // If this is a one region operation, then insert into it.
60     return isa<OneRegionOp>(region->getParentOp());
61   }
62 };
63 
64 /// This class defines the interface for handling inlining with standard
65 /// operations.
66 struct TestInlinerInterface : public DialectInlinerInterface {
67   using DialectInlinerInterface::DialectInlinerInterface;
68 
69   //===--------------------------------------------------------------------===//
70   // Analysis Hooks
71   //===--------------------------------------------------------------------===//
72 
73   bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
74     // Inlining into test dialect regions is legal.
75     return true;
76   }
77   bool isLegalToInline(Operation *, Region *,
78                        BlockAndValueMapping &) const final {
79     return true;
80   }
81 
82   bool shouldAnalyzeRecursively(Operation *op) const final {
83     // Analyze recursively if this is not a functional region operation, it
84     // froms a separate functional scope.
85     return !isa<FunctionalRegionOp>(op);
86   }
87 
88   //===--------------------------------------------------------------------===//
89   // Transformation Hooks
90   //===--------------------------------------------------------------------===//
91 
92   /// Handle the given inlined terminator by replacing it with a new operation
93   /// as necessary.
94   void handleTerminator(Operation *op,
95                         ArrayRef<Value> valuesToRepl) const final {
96     // Only handle "test.return" here.
97     auto returnOp = dyn_cast<TestReturnOp>(op);
98     if (!returnOp)
99       return;
100 
101     // Replace the values directly with the return operands.
102     assert(returnOp.getNumOperands() == valuesToRepl.size());
103     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
104       valuesToRepl[it.index()].replaceAllUsesWith(it.value());
105   }
106 
107   /// Attempt to materialize a conversion for a type mismatch between a call
108   /// from this dialect, and a callable region. This method should generate an
109   /// operation that takes 'input' as the only operand, and produces a single
110   /// result of 'resultType'. If a conversion can not be generated, nullptr
111   /// should be returned.
112   Operation *materializeCallConversion(OpBuilder &builder, Value input,
113                                        Type resultType,
114                                        Location conversionLoc) const final {
115     // Only allow conversion for i16/i32 types.
116     if (!(resultType.isSignlessInteger(16) ||
117           resultType.isSignlessInteger(32)) ||
118         !(input.getType().isSignlessInteger(16) ||
119           input.getType().isSignlessInteger(32)))
120       return nullptr;
121     return builder.create<TestCastOp>(conversionLoc, resultType, input);
122   }
123 };
124 } // end anonymous namespace
125 
126 //===----------------------------------------------------------------------===//
127 // TestDialect
128 //===----------------------------------------------------------------------===//
129 
/// Constructs the test dialect. It registers every operation from the
/// GET_OP_LIST section of the generated TestOps.cpp.inc, and attaches the
/// AsmPrinter, folder and inliner interfaces defined above.
TestDialect::TestDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
                TestInlinerInterface>();
  // Also accept operations in this dialect's namespace that were not
  // registered above.
  allowUnknownOperations();
}
140 
141 LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
142                                                     NamedAttribute namedAttr) {
143   if (namedAttr.first == "test.invalid_attr")
144     return op->emitError() << "invalid to use 'test.invalid_attr'";
145   return success();
146 }
147 
148 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
149                                                     unsigned regionIndex,
150                                                     unsigned argIndex,
151                                                     NamedAttribute namedAttr) {
152   if (namedAttr.first == "test.invalid_attr")
153     return op->emitError() << "invalid to use 'test.invalid_attr'";
154   return success();
155 }
156 
157 LogicalResult
158 TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
159                                          unsigned resultIndex,
160                                          NamedAttribute namedAttr) {
161   if (namedAttr.first == "test.invalid_attr")
162     return op->emitError() << "invalid to use 'test.invalid_attr'";
163   return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // TestBranchOp
168 //===----------------------------------------------------------------------===//
169 
170 Optional<MutableOperandRange>
171 TestBranchOp::getMutableSuccessorOperands(unsigned index) {
172   assert(index == 0 && "invalid successor index");
173   return targetOperandsMutable();
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // Test IsolatedRegionOp - parse passthrough region arguments.
178 //===----------------------------------------------------------------------===//
179 
180 static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
181                                          OperationState &result) {
182   OpAsmParser::OperandType argInfo;
183   Type argType = parser.getBuilder().getIndexType();
184 
185   // Parse the input operand.
186   if (parser.parseOperand(argInfo) ||
187       parser.resolveOperand(argInfo, argType, result.operands))
188     return failure();
189 
190   // Parse the body region, and reuse the operand info as the argument info.
191   Region *body = result.addRegion();
192   return parser.parseRegion(*body, argInfo, argType,
193                             /*enableNameShadowing=*/true);
194 }
195 
196 static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
197   p << "test.isolated_region ";
198   p.printOperand(op.getOperand());
199   p.shadowRegionArgs(op.region(), op.getOperand());
200   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // Test AffineScopeOp
205 //===----------------------------------------------------------------------===//
206 
207 static ParseResult parseAffineScopeOp(OpAsmParser &parser,
208                                       OperationState &result) {
209   // Parse the body region, and reuse the operand info as the argument info.
210   Region *body = result.addRegion();
211   return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
212 }
213 
214 static void print(OpAsmPrinter &p, AffineScopeOp op) {
215   p << "test.affine_scope ";
216   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // Test parser.
221 //===----------------------------------------------------------------------===//
222 
223 static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
224                                          OperationState &result) {
225   StringRef keyword;
226   if (parser.parseKeyword(&keyword))
227     return failure();
228   result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
229   return success();
230 }
231 
232 static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
233   p << WrappedKeywordOp::getOperationName() << " " << op.keyword();
234 }
235 
236 //===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
238 
239 static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
240                                          OperationState &result) {
241   if (parser.parseKeyword("wraps"))
242     return failure();
243 
244   // Parse the wrapped op in a region
245   Region &body = *result.addRegion();
246   body.push_back(new Block);
247   Block &block = body.back();
248   Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin());
249   if (!wrapped_op)
250     return failure();
251 
252   // Create a return terminator in the inner region, pass as operand to the
253   // terminator the returned values from the wrapped operation.
254   SmallVector<Value, 8> return_operands(wrapped_op->getResults());
255   OpBuilder builder(parser.getBuilder().getContext());
256   builder.setInsertionPointToEnd(&block);
257   builder.create<TestReturnOp>(wrapped_op->getLoc(), return_operands);
258 
259   // Get the results type for the wrapping op from the terminator operands.
260   Operation &return_op = body.back().back();
261   result.types.append(return_op.operand_type_begin(),
262                       return_op.operand_type_end());
263 
264   // Use the location of the wrapped op for the "test.wrapping_region" op.
265   result.location = wrapped_op->getLoc();
266 
267   return success();
268 }
269 
270 static void print(OpAsmPrinter &p, WrappingRegionOp op) {
271   p << op.getOperationName() << " wraps ";
272   p.printGenericOp(&op.region().front().front());
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // Test PolyForOp - parse list of region arguments.
277 //===----------------------------------------------------------------------===//
278 
279 static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
280   SmallVector<OpAsmParser::OperandType, 4> ivsInfo;
281   // Parse list of region arguments without a delimiter.
282   if (parser.parseRegionArgumentList(ivsInfo))
283     return failure();
284 
285   // Parse the body region.
286   Region *body = result.addRegion();
287   auto &builder = parser.getBuilder();
288   SmallVector<Type, 4> argTypes(ivsInfo.size(), builder.getIndexType());
289   return parser.parseRegion(*body, ivsInfo, argTypes);
290 }
291 
292 //===----------------------------------------------------------------------===//
293 // Test removing op with inner ops.
294 //===----------------------------------------------------------------------===//
295 
296 namespace {
297 struct TestRemoveOpWithInnerOps
298     : public OpRewritePattern<TestOpWithRegionPattern> {
299   using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
300 
301   LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
302                                 PatternRewriter &rewriter) const override {
303     rewriter.eraseOp(op);
304     return success();
305   }
306 };
307 } // end anonymous namespace
308 
309 void TestOpWithRegionPattern::getCanonicalizationPatterns(
310     OwningRewritePatternList &results, MLIRContext *context) {
311   results.insert<TestRemoveOpWithInnerOps>(context);
312 }
313 
314 OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
315   return operand();
316 }
317 
318 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
319     ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
320   for (Value input : this->operands()) {
321     results.push_back(input);
322   }
323   return success();
324 }
325 
326 OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
327   assert(operands.size() == 1);
328   if (operands.front()) {
329     setAttr("attr", operands.front());
330     return getResult();
331   }
332   return {};
333 }
334 
335 LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
336     MLIRContext *, Optional<Location> location, ValueRange operands,
337     DictionaryAttr attributes, RegionRange regions,
338     SmallVectorImpl<Type> &inferredReturnTypes) {
339   if (operands[0].getType() != operands[1].getType()) {
340     return emitOptionalError(location, "operand type mismatch ",
341                              operands[0].getType(), " vs ",
342                              operands[1].getType());
343   }
344   inferredReturnTypes.assign({operands[0].getType()});
345   return success();
346 }
347 
348 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
349     MLIRContext *context, Optional<Location> location, ValueRange operands,
350     DictionaryAttr attributes, RegionRange regions,
351     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
352   // Create return type consisting of the last element of the first operand.
353   auto operandType = *operands.getTypes().begin();
354   auto sval = operandType.dyn_cast<ShapedType>();
355   if (!sval) {
356     return emitOptionalError(location, "only shaped type operands allowed");
357   }
358   int64_t dim =
359       sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
360   auto type = IntegerType::get(17, context);
361   inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
362   return success();
363 }
364 
365 LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
366     OpBuilder &builder, llvm::SmallVectorImpl<Value> &shapes) {
367   shapes = SmallVector<Value, 1>{
368       builder.createOrFold<mlir::DimOp>(getLoc(), getOperand(0), 0)};
369   return success();
370 }
371 
372 //===----------------------------------------------------------------------===//
373 // Test SideEffect interfaces
374 //===----------------------------------------------------------------------===//
375 
376 namespace {
377 /// A test resource for side effects.
378 struct TestResource : public SideEffects::Resource::Base<TestResource> {
379   StringRef getName() final { return "<Test>"; }
380 };
381 } // end anonymous namespace
382 
383 void SideEffectOp::getEffects(
384     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
385   // Check for an effects attribute on the op instance.
386   ArrayAttr effectsAttr = getAttrOfType<ArrayAttr>("effects");
387   if (!effectsAttr)
388     return;
389 
390   // If there is one, it is an array of dictionary attributes that hold
391   // information on the effects of this operation.
392   for (Attribute element : effectsAttr) {
393     DictionaryAttr effectElement = element.cast<DictionaryAttr>();
394 
395     // Get the specific memory effect.
396     MemoryEffects::Effect *effect =
397         llvm::StringSwitch<MemoryEffects::Effect *>(
398             effectElement.get("effect").cast<StringAttr>().getValue())
399             .Case("allocate", MemoryEffects::Allocate::get())
400             .Case("free", MemoryEffects::Free::get())
401             .Case("read", MemoryEffects::Read::get())
402             .Case("write", MemoryEffects::Write::get());
403 
404     // Check for a result to affect.
405     Value value;
406     if (effectElement.get("on_result"))
407       value = getResult();
408 
409     // Check for a non-default resource to use.
410     SideEffects::Resource *resource = SideEffects::DefaultResource::get();
411     if (effectElement.get("test_resource"))
412       resource = TestResource::get();
413 
414     effects.emplace_back(effect, value, resource);
415   }
416 }
417 
418 //===----------------------------------------------------------------------===//
419 // StringAttrPrettyNameOp
420 //===----------------------------------------------------------------------===//
421 
422 // This op has fancy handling of its SSA result name.
423 static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
424                                                OperationState &result) {
425   // Add the result types.
426   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
427     result.addTypes(parser.getBuilder().getIntegerType(32));
428 
429   if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
430     return failure();
431 
432   // If the attribute dictionary contains no 'names' attribute, infer it from
433   // the SSA name (if specified).
434   bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
435     return attr.first == "names";
436   });
437 
438   // If there was no name specified, check to see if there was a useful name
439   // specified in the asm file.
440   if (hadNames || parser.getNumResults() == 0)
441     return success();
442 
443   SmallVector<StringRef, 4> names;
444   auto *context = result.getContext();
445 
446   for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
447     auto resultName = parser.getResultName(i);
448     StringRef nameStr;
449     if (!resultName.first.empty() && !isdigit(resultName.first[0]))
450       nameStr = resultName.first;
451 
452     names.push_back(nameStr);
453   }
454 
455   auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
456   result.attributes.push_back({Identifier::get("names", context), namesAttr});
457   return success();
458 }
459 
460 static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
461   p << "test.string_attr_pretty_name";
462 
463   // Note that we only need to print the "name" attribute if the asmprinter
464   // result name disagrees with it.  This can happen in strange cases, e.g.
465   // when there are conflicts.
466   bool namesDisagree = op.names().size() != op.getNumResults();
467 
468   SmallString<32> resultNameStr;
469   for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
470     resultNameStr.clear();
471     llvm::raw_svector_ostream tmpStream(resultNameStr);
472     p.printOperand(op.getResult(i), tmpStream);
473 
474     auto expectedName = op.names()[i].dyn_cast<StringAttr>();
475     if (!expectedName ||
476         tmpStream.str().drop_front() != expectedName.getValue()) {
477       namesDisagree = true;
478     }
479   }
480 
481   if (namesDisagree)
482     p.printOptionalAttrDictWithKeyword(op.getAttrs());
483   else
484     p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
485 }
486 
487 // We set the SSA name in the asm syntax to the contents of the name
488 // attribute.
489 void StringAttrPrettyNameOp::getAsmResultNames(
490     function_ref<void(Value, StringRef)> setNameFn) {
491 
492   auto value = names();
493   for (size_t i = 0, e = value.size(); i != e; ++i)
494     if (auto str = value[i].dyn_cast<StringAttr>())
495       if (!str.getValue().empty())
496         setNameFn(getResult(i), str.getValue());
497 }
498 
499 //===----------------------------------------------------------------------===//
500 // Dialect Registration
501 //===----------------------------------------------------------------------===//
502 
// Static initialization for Test dialect registration: the registration
// happens as a side effect of constructing this file-local object. The
// variable is not otherwise referenced in this file.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
505 
506 #include "TestOpEnums.cpp.inc"
507 #include "TestOpStructs.cpp.inc"
508 
509 #define GET_OP_CLASSES
510 #include "TestOps.cpp.inc"
511