1 //===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 /// A pass for testing call graph type decomposition.
20 ///
21 /// This instantiates the patterns with a TypeConverter and ValueDecomposer
22 /// that splits tuple types into their respective element types.
23 /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
24 struct TestDecomposeCallGraphTypes
25     : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon61c7882a0111::TestDecomposeCallGraphTypes26   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeCallGraphTypes)
27 
28   void getDependentDialects(DialectRegistry &registry) const override {
29     registry.insert<test::TestDialect>();
30   }
getArgument__anon61c7882a0111::TestDecomposeCallGraphTypes31   StringRef getArgument() const final {
32     return "test-decompose-call-graph-types";
33   }
getDescription__anon61c7882a0111::TestDecomposeCallGraphTypes34   StringRef getDescription() const final {
35     return "Decomposes types at call graph boundaries.";
36   }
runOnOperation__anon61c7882a0111::TestDecomposeCallGraphTypes37   void runOnOperation() override {
38     ModuleOp module = getOperation();
39     auto *context = &getContext();
40     TypeConverter typeConverter;
41     ConversionTarget target(*context);
42     ValueDecomposer decomposer;
43     RewritePatternSet patterns(context);
44 
45     target.addLegalDialect<test::TestDialect>();
46 
47     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
48       return typeConverter.isLegal(op.getOperandTypes());
49     });
50     target.addDynamicallyLegalOp<func::CallOp>(
51         [&](func::CallOp op) { return typeConverter.isLegal(op); });
52     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
53       return typeConverter.isSignatureLegal(op.getFunctionType());
54     });
55 
56     typeConverter.addConversion([](Type type) { return type; });
57     typeConverter.addConversion(
58         [](TupleType tupleType, SmallVectorImpl<Type> &types) {
59           tupleType.getFlattenedTypes(types);
60           return success();
61         });
62 
63     decomposer.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
64                                               TupleType resultType, Value value,
65                                               SmallVectorImpl<Value> &values) {
66       for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
67         Value res = builder.create<test::GetTupleElementOp>(
68             loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
69         values.push_back(res);
70       }
71       return success();
72     });
73 
74     typeConverter.addArgumentMaterialization(
75         [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
76            Location loc) -> Optional<Value> {
77           if (inputs.size() == 1)
78             return llvm::None;
79           TupleType tuple = builder.getTupleType(inputs.getTypes());
80           Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
81           return value;
82         });
83 
84     populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
85                                             patterns);
86 
87     if (failed(applyPartialConversion(module, target, std::move(patterns))))
88       return signalPassFailure();
89   }
90 };
91 
92 } // namespace
93 
94 namespace mlir {
95 namespace test {
registerTestDecomposeCallGraphTypes()96 void registerTestDecomposeCallGraphTypes() {
97   PassRegistration<TestDecomposeCallGraphTypes>();
98 }
99 } // namespace test
100 } // namespace mlir
101