1 //===- DecomposeCallGraphTypes.cpp - CG type decomposition ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h"
10 #include "mlir/Dialect/Func/IR/FuncOps.h"
11 #include "mlir/IR/BuiltinOps.h"
12 
13 using namespace mlir;
14 using namespace mlir::func;
15 
16 //===----------------------------------------------------------------------===//
17 // ValueDecomposer
18 //===----------------------------------------------------------------------===//
19 
decomposeValue(OpBuilder & builder,Location loc,Type type,Value value,SmallVectorImpl<Value> & results)20 void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
21                                      Type type, Value value,
22                                      SmallVectorImpl<Value> &results) {
23   for (auto &conversion : decomposeValueConversions)
24     if (conversion(builder, loc, type, value, results))
25       return;
26   results.push_back(value);
27 }
28 
29 //===----------------------------------------------------------------------===//
30 // DecomposeCallGraphTypesOpConversionPattern
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 /// Base OpConversionPattern class to make a ValueDecomposer available to
35 /// inherited patterns.
36 template <typename SourceOp>
37 class DecomposeCallGraphTypesOpConversionPattern
38     : public OpConversionPattern<SourceOp> {
39 public:
DecomposeCallGraphTypesOpConversionPattern(TypeConverter & typeConverter,MLIRContext * context,ValueDecomposer & decomposer,PatternBenefit benefit=1)40   DecomposeCallGraphTypesOpConversionPattern(TypeConverter &typeConverter,
41                                              MLIRContext *context,
42                                              ValueDecomposer &decomposer,
43                                              PatternBenefit benefit = 1)
44       : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45         decomposer(decomposer) {}
46 
47 protected:
48   ValueDecomposer &decomposer;
49 };
50 } // namespace
51 
52 //===----------------------------------------------------------------------===//
53 // DecomposeCallGraphTypesForFuncArgs
54 //===----------------------------------------------------------------------===//
55 
56 namespace {
57 /// Expand function arguments according to the provided TypeConverter and
58 /// ValueDecomposer.
59 struct DecomposeCallGraphTypesForFuncArgs
60     : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61   using DecomposeCallGraphTypesOpConversionPattern::
62       DecomposeCallGraphTypesOpConversionPattern;
63 
64   LogicalResult
matchAndRewrite__anon1954b0190211::DecomposeCallGraphTypesForFuncArgs65   matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
66                   ConversionPatternRewriter &rewriter) const final {
67     auto functionType = op.getFunctionType();
68 
69     // Convert function arguments using the provided TypeConverter.
70     TypeConverter::SignatureConversion conversion(functionType.getNumInputs());
71     for (const auto &argType : llvm::enumerate(functionType.getInputs())) {
72       SmallVector<Type, 2> decomposedTypes;
73       if (failed(typeConverter->convertType(argType.value(), decomposedTypes)))
74         return failure();
75       if (!decomposedTypes.empty())
76         conversion.addInputs(argType.index(), decomposedTypes);
77     }
78 
79     // If the SignatureConversion doesn't apply, bail out.
80     if (failed(rewriter.convertRegionTypes(&op.getBody(), *getTypeConverter(),
81                                            &conversion)))
82       return failure();
83 
84     // Update the signature of the function.
85     SmallVector<Type, 2> newResultTypes;
86     if (failed(typeConverter->convertTypes(functionType.getResults(),
87                                            newResultTypes)))
88       return failure();
89     rewriter.updateRootInPlace(op, [&] {
90       op.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
91                                           newResultTypes));
92     });
93     return success();
94   }
95 };
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // DecomposeCallGraphTypesForReturnOp
100 //===----------------------------------------------------------------------===//
101 
102 namespace {
103 /// Expand return operands according to the provided TypeConverter and
104 /// ValueDecomposer.
105 struct DecomposeCallGraphTypesForReturnOp
106     : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
107   using DecomposeCallGraphTypesOpConversionPattern::
108       DecomposeCallGraphTypesOpConversionPattern;
109   LogicalResult
matchAndRewrite__anon1954b0190411::DecomposeCallGraphTypesForReturnOp110   matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
111                   ConversionPatternRewriter &rewriter) const final {
112     SmallVector<Value, 2> newOperands;
113     for (Value operand : adaptor.getOperands())
114       decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
115                                 operand, newOperands);
116     rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
117     return success();
118   }
119 };
120 } // namespace
121 
122 //===----------------------------------------------------------------------===//
123 // DecomposeCallGraphTypesForCallOp
124 //===----------------------------------------------------------------------===//
125 
126 namespace {
127 /// Expand call op operands and results according to the provided TypeConverter
128 /// and ValueDecomposer.
129 struct DecomposeCallGraphTypesForCallOp
130     : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131   using DecomposeCallGraphTypesOpConversionPattern::
132       DecomposeCallGraphTypesOpConversionPattern;
133 
134   LogicalResult
matchAndRewrite__anon1954b0190511::DecomposeCallGraphTypesForCallOp135   matchAndRewrite(CallOp op, OpAdaptor adaptor,
136                   ConversionPatternRewriter &rewriter) const final {
137 
138     // Create the operands list of the new `CallOp`.
139     SmallVector<Value, 2> newOperands;
140     for (Value operand : adaptor.getOperands())
141       decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
142                                 operand, newOperands);
143 
144     // Create the new result types for the new `CallOp` and track the indices in
145     // the new call op's results that correspond to the old call op's results.
146     //
147     // expandedResultIndices[i] = "list of new result indices that old result i
148     // expanded to".
149     SmallVector<Type, 2> newResultTypes;
150     SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
151     for (Type resultType : op.getResultTypes()) {
152       unsigned oldSize = newResultTypes.size();
153       if (failed(typeConverter->convertType(resultType, newResultTypes)))
154         return failure();
155       auto &resultMapping = expandedResultIndices.emplace_back();
156       for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
157         resultMapping.push_back(i);
158     }
159 
160     CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
161                                                newResultTypes, newOperands);
162 
163     // Build a replacement value for each result to replace its uses. If a
164     // result has multiple mapping values, it needs to be materialized as a
165     // single value.
166     SmallVector<Value, 2> replacedValues;
167     replacedValues.reserve(op.getNumResults());
168     for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
169       auto decomposedValues = llvm::to_vector<6>(
170           llvm::map_range(expandedResultIndices[i],
171                           [&](unsigned i) { return newCallOp.getResult(i); }));
172       if (decomposedValues.empty()) {
173         // No replacement is required.
174         replacedValues.push_back(nullptr);
175       } else if (decomposedValues.size() == 1) {
176         replacedValues.push_back(decomposedValues.front());
177       } else {
178         // Materialize a single Value to replace the original Value.
179         Value materialized = getTypeConverter()->materializeArgumentConversion(
180             rewriter, op.getLoc(), op.getType(i), decomposedValues);
181         replacedValues.push_back(materialized);
182       }
183     }
184     rewriter.replaceOp(op, replacedValues);
185     return success();
186   }
187 };
188 } // namespace
189 
populateDecomposeCallGraphTypesPatterns(MLIRContext * context,TypeConverter & typeConverter,ValueDecomposer & decomposer,RewritePatternSet & patterns)190 void mlir::populateDecomposeCallGraphTypesPatterns(
191     MLIRContext *context, TypeConverter &typeConverter,
192     ValueDecomposer &decomposer, RewritePatternSet &patterns) {
193   patterns
194       .add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195            DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196                                                decomposer);
197 }
198