1 //===- TestTypes.cpp - Test passes for MLIR types -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestTypes.h"
10 #include "TestDialect.h"
11 #include "mlir/Pass/Pass.h"
12 
13 using namespace mlir;
14 using namespace test;
15 
16 namespace {
17 struct TestRecursiveTypesPass
18     : public PassWrapper<TestRecursiveTypesPass, OperationPass<func::FuncOp>> {
19   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestRecursiveTypesPass)
20 
21   LogicalResult createIRWithTypes();
22 
getArgument__anone5d950a70111::TestRecursiveTypesPass23   StringRef getArgument() const final { return "test-recursive-types"; }
getDescription__anone5d950a70111::TestRecursiveTypesPass24   StringRef getDescription() const final {
25     return "Test support for recursive types";
26   }
runOnOperation__anone5d950a70111::TestRecursiveTypesPass27   void runOnOperation() override {
28     func::FuncOp func = getOperation();
29 
30     // Just make sure recursive types are printed and parsed.
31     if (func.getName() == "roundtrip")
32       return;
33 
34     // Create a recursive type and print it as a part of a dummy op.
35     if (func.getName() == "create") {
36       if (failed(createIRWithTypes()))
37         signalPassFailure();
38       return;
39     }
40 
41     // Unknown key.
42     func.emitOpError() << "unexpected function name";
43     signalPassFailure();
44   }
45 };
46 } // namespace
47 
createIRWithTypes()48 LogicalResult TestRecursiveTypesPass::createIRWithTypes() {
49   MLIRContext *ctx = &getContext();
50   func::FuncOp func = getOperation();
51   auto type = TestRecursiveType::get(ctx, "some_long_and_unique_name");
52   if (failed(type.setBody(type)))
53     return func.emitError("expected to be able to set the type body");
54 
55   // Setting the same body is fine.
56   if (failed(type.setBody(type)))
57     return func.emitError(
58         "expected to be able to set the type body to the same value");
59 
60   // Setting a different body is not.
61   if (succeeded(type.setBody(IndexType::get(ctx))))
62     return func.emitError(
63         "not expected to be able to change function body more than once");
64 
65   // Expecting to get the same type for the same name.
66   auto other = TestRecursiveType::get(ctx, "some_long_and_unique_name");
67   if (type != other)
68     return func.emitError("expected type name to be the uniquing key");
69 
70   // Create the op to check how the type is printed.
71   OperationState state(func.getLoc(), "test.dummy_type_test_op");
72   state.addTypes(type);
73   func.getBody().front().push_front(Operation::create(state));
74 
75   return success();
76 }
77 
78 namespace mlir {
79 namespace test {
80 
registerTestRecursiveTypesPass()81 void registerTestRecursiveTypesPass() {
82   PassRegistration<TestRecursiveTypesPass>();
83 }
84 
85 } // namespace test
86 } // namespace mlir
87