1 //===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/IR/BuiltinOps.h"
12
13 using namespace mlir;
14 using namespace mlir::func;
15
16 //===----------------------------------------------------------------------===//
17 // ValueDecomposer
18 //===----------------------------------------------------------------------===//
19
decomposeValue(OpBuilder & builder,Location loc,Type type,Value value,SmallVectorImpl<Value> & results)20 void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
21 Type type, Value value,
22 SmallVectorImpl<Value> &results) {
23 for (auto &conversion : decomposeValueConversions)
24 if (conversion(builder, loc, type, value, results))
25 return;
26 results.push_back(value);
27 }
28
29 //===----------------------------------------------------------------------===//
30 // DecomposeCallGraphTypesOpConversionPattern
31 //===----------------------------------------------------------------------===//
32
33 namespace {
34 /// Base OpConversionPattern class to make a ValueDecomposer available to
35 /// inherited patterns.
36 template <typename SourceOp>
37 class DecomposeCallGraphTypesOpConversionPattern
38 : public OpConversionPattern<SourceOp> {
39 public:
DecomposeCallGraphTypesOpConversionPattern(TypeConverter & typeConverter,MLIRContext * context,ValueDecomposer & decomposer,PatternBenefit benefit=1)40 DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
41 MLIRContext *context,
42 ValueDecomposer &decomposer,
43 PatternBenefit benefit = 1)
44 : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45 decomposer(decomposer) {}
46
47 protected:
48 ValueDecomposer &decomposer;
49 };
50 } // namespace
51
52 //===----------------------------------------------------------------------===//
53 // DecomposeCallGraphTypesForFuncArgs
54 //===----------------------------------------------------------------------===//
55
56 namespace {
57 /// Expand function arguments according to the provided TypeConverter and
58 /// ValueDecomposer.
59 struct DecomposeCallGraphTypesForFuncArgs
60 : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61 using DecomposeCallGraphTypesOpConversionPattern::
62 DecomposeCallGraphTypesOpConversionPattern;
63
64 LogicalResult
matchAndRewrite__anon1954b0190211::DecomposeCallGraphTypesForFuncArgs65 matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
66 ConversionPatternRewriter &rewriter) const final {
67 auto functionType = op.getFunctionType();
68
69 // Convert function arguments using the provided TypeConverter.
70 TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
71 for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
72 SmallVector<Type, 2> decomposedTypes;
73 if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
74 return failure();
75 if (!decomposedTypes.empty())
76 conversion.addInputs(argType.index(), decomposedTypes);
77 }
78
79 // If the SignatureConversion doesn't apply, bail out.
80 if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
81 &conversion)))
82 return failure();
83
84 // Update the signature of the function.
85 SmallVector<Type, 2> newResultTypes;
86 if (failed(typeConverter->convertTypes(functionType.getResults(),
87 newResultTypes)))
88 return failure();
89 rewriter.updateRootInPlace(op, [&] {
90 op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
91 newResultTypes));
92 });
93 return success();
94 }
95 };
96 } // namespace
97
98 //===----------------------------------------------------------------------===//
99 // DecomposeCallGraphTypesForReturnOp
100 //===----------------------------------------------------------------------===//
101
102 namespace {
103 /// Expand return operands according to the provided TypeConverter and
104 /// ValueDecomposer.
105 struct DecomposeCallGraphTypesForReturnOp
106 : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
107 using DecomposeCallGraphTypesOpConversionPattern::
108 DecomposeCallGraphTypesOpConversionPattern;
109 LogicalResult
matchAndRewrite__anon1954b0190411::DecomposeCallGraphTypesForReturnOp110 matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
111 ConversionPatternRewriter &rewriter) const final {
112 SmallVector<Value, 2> newOperands;
113 for (Value operand : adaptor.getOperands())
114 decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
115 operand, newOperands);
116 rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
117 return success();
118 }
119 };
120 } // namespace
121
122 //===----------------------------------------------------------------------===//
123 // DecomposeCallGraphTypesForCallOp
124 //===----------------------------------------------------------------------===//
125
126 namespace {
127 /// Expand call op operands and results according to the provided TypeConverter
128 /// and ValueDecomposer.
129 struct DecomposeCallGraphTypesForCallOp
130 : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131 using DecomposeCallGraphTypesOpConversionPattern::
132 DecomposeCallGraphTypesOpConversionPattern;
133
134 LogicalResult
matchAndRewrite__anon1954b0190511::DecomposeCallGraphTypesForCallOp135 matchAndRewrite(CallOp op, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter) const final {
137
138 // Create the operands list of the new `CallOp`.
139 SmallVector<Value, 2> newOperands;
140 for (Value operand : adaptor.getOperands())
141 decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
142 operand, newOperands);
143
144 // Create the new result types for the new `CallOp` and track the indices in
145 // the new call op's results that correspond to the old call op's results.
146 //
147 // expandedResultIndices[i] = "list of new result indices that old result i
148 // expanded to".
149 SmallVector<Type, 2> newResultTypes;
150 SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
151 for (Type resultType : op.getResultTypes()) {
152 unsigned oldSize = newResultTypes.size();
153 if (failed(typeConverter->convertType(resultType, newResultTypes)))
154 return failure();
155 auto &resultMapping = expandedResultIndices.emplace_back();
156 for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
157 resultMapping.push_back(i);
158 }
159
160 CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
161 newResultTypes, newOperands);
162
163 // Build a replacement value for each result to replace its uses. If a
164 // result has multiple mapping values, it needs to be materialized as a
165 // single value.
166 SmallVector<Value, 2> replacedValues;
167 replacedValues.reserve(op.getNumResults());
168 for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
169 auto decomposedValues = llvm::to_vector<6>(
170 llvm::map_range(expandedResultIndices[i],
171 [&](unsigned i) { return newCallOp.getResult(i); }));
172 if (decomposedValues.empty()) {
173 // No replacement is required.
174 replacedValues.push_back(nullptr);
175 } else if (decomposedValues.size() == 1) {
176 replacedValues.push_back(decomposedValues.front());
177 } else {
178 // Materialize a single Value to replace the original Value.
179 Value materialized = getTypeConverter()->materializeArgumentConversion(
180 rewriter, op.getLoc(), op.getType(i), decomposedValues);
181 replacedValues.push_back(materialized);
182 }
183 }
184 rewriter.replaceOp(op, replacedValues);
185 return success();
186 }
187 };
188 } // namespace
189
populateDecomposeCallGraphTypesPatterns(MLIRContext * context,TypeConverter & typeConverter,ValueDecomposer & decomposer,RewritePatternSet & patterns)190 void mlir::populateDecomposeCallGraphTypesPatterns(
191 MLIRContext *context, TypeConverter &typeConverter,
192 ValueDecomposer &decomposer, RewritePatternSet &patterns) {
193 patterns
194 .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195 DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196 decomposer);
197 }
198