1 //===- TestDecomposeCallGraphTypes.cpp - Test CG type decomposition -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "TestDialect.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15
16 using namespace mlir;
17
18 namespace {
19 /// A pass for testing call graph type decomposition.
20 ///
21 /// This instantiates the patterns with a TypeConverter and ValueDecomposer
22 /// that splits tuple types into their respective element types.
23 /// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
24 struct TestDecomposeCallGraphTypes
25 : public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anon61c7882a0111::TestDecomposeCallGraphTypes26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDecomposeCallGraphTypes)
27
28 void getDependentDialects(DialectRegistry ®istry) const override {
29 registry.insert<test::TestDialect>();
30 }
getArgument__anon61c7882a0111::TestDecomposeCallGraphTypes31 StringRef getArgument() const final {
32 return "test-decompose-call-graph-types";
33 }
getDescription__anon61c7882a0111::TestDecomposeCallGraphTypes34 StringRef getDescription() const final {
35 return "Decomposes types at call graph boundaries.";
36 }
runOnOperation__anon61c7882a0111::TestDecomposeCallGraphTypes37 void runOnOperation() override {
38 ModuleOp module = getOperation();
39 auto *context = &getContext();
40 TypeConverter typeConverter;
41 ConversionTarget target(*context);
42 ValueDecomposer decomposer;
43 RewritePatternSet patterns(context);
44
45 target.addLegalDialect<test::TestDialect>();
46
47 target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
48 return typeConverter.isLegal(op.getOperandTypes());
49 });
50 target.addDynamicallyLegalOp<func::CallOp>(
51 [&](func::CallOp op) { return typeConverter.isLegal(op); });
52 target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
53 return typeConverter.isSignatureLegal(op.getFunctionType());
54 });
55
56 typeConverter.addConversion([](Type type) { return type; });
57 typeConverter.addConversion(
58 [](TupleType tupleType, SmallVectorImpl<Type> &types) {
59 tupleType.getFlattenedTypes(types);
60 return success();
61 });
62
63 decomposer.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
64 TupleType resultType, Value value,
65 SmallVectorImpl<Value> &values) {
66 for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
67 Value res = builder.create<test::GetTupleElementOp>(
68 loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
69 values.push_back(res);
70 }
71 return success();
72 });
73
74 typeConverter.addArgumentMaterialization(
75 [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
76 Location loc) -> Optional<Value> {
77 if (inputs.size() == 1)
78 return llvm::None;
79 TupleType tuple = builder.getTupleType(inputs.getTypes());
80 Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
81 return value;
82 });
83
84 populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
85 patterns);
86
87 if (failed(applyPartialConversion(module, target, std::move(patterns))))
88 return signalPassFailure();
89 }
90 };
91
92 } // namespace
93
94 namespace mlir {
95 namespace test {
registerTestDecomposeCallGraphTypes()96 void registerTestDecomposeCallGraphTypes() {
97 PassRegistration<TestDecomposeCallGraphTypes>();
98 }
99 } // namespace test
100 } // namespace mlir
101