1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert reshape to binary op's input if needed to match rank
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
17 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 using namespace mlir::tosa;
24 
25 /// There are two potential ways implementing broadcast:
26 /// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
27 /// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
28 /// This pass implements b (numpy style) now.
29 
30 /// In this pass, we insert RESHAPE operators to increase the rank of the
31 /// lower rank operand as a first step in the broadcasting process. The TOSA
32 /// operators that support broadcast require that the rank of the operands
33 /// are equal.
34 
35 // Examples:
36 // If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
37 // If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
38 // If lower=[a], higher=[a, a], [a] reshaped into [1, a].
// If lower=[a], higher=[a, b, a], [a] reshaped into [1, 1, a].
// If lower=[], higher=[a, b, c], [] reshaped into [1, 1, 1].
41 
42 static LogicalResult
43 computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
44                      ArrayRef<int64_t> lowerRankShape,
45                      SmallVectorImpl<int64_t> &reshapeOutputShape) {
46   // Initialize new shapes with [1] * higherRank.
47   int64_t higherRank = higherRankShape.size();
48   int64_t lowerRank = lowerRankShape.size();
49 
50   reshapeOutputShape.assign(higherRank, 1);
51 
52   int64_t higherRankDim;
53   int64_t lowerRankDim;
54 
55   for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
56        i--, j--) {
57     higherRankDim = higherRankShape[i];
58     lowerRankDim = lowerRankShape[j];
59 
60     if (lowerRankDim == 1 && higherRankDim > 1)
61       reshapeOutputShape[i] = 1;
62     else if ((lowerRankDim > 1 && higherRankDim == 1) ||
63              (lowerRankDim == higherRankDim))
64       reshapeOutputShape[i] = lowerRankDim;
65     else if (higherRankDim != lowerRankDim)
66       return failure();
67   }
68   return success();
69 }
70 
71 /// Common code to create the reshape op where necessary to make the rank of the
72 /// operations equal. Returns the updated input1 and input2 for the original
73 /// input. The caller is expected to use these to rewrite the original operator
74 /// with the RESHAPE now in the graph.
75 static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
76                                           Location loc,
77                                           RankedTensorType outputType,
78                                           Value input1, Value input2,
79                                           Value &outInput1, Value &outInput2) {
80   auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
81   auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
82 
83   if (!input1Ty || !input2Ty)
84     return failure();
85 
86   int64_t input1Rank = input1Ty.getRank();
87   int64_t input2Rank = input2Ty.getRank();
88 
89   Value higherTensorValue, lowerTensorValue;
90   // Cannot rewrite as its already correct.
91   if (input1Rank == input2Rank)
92     return failure();
93 
94   if (input1Rank > input2Rank) {
95     higherTensorValue = input1;
96     lowerTensorValue = input2;
97   } else {
98     higherTensorValue = input2;
99     lowerTensorValue = input1;
100   }
101 
102   ArrayRef<int64_t> higherRankShape =
103       higherTensorValue.getType().cast<RankedTensorType>().getShape();
104   (void)higherRankShape;
105   ArrayRef<int64_t> lowerRankShape =
106       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
107 
108   SmallVector<int64_t, 4> reshapeOutputShape;
109 
110   if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
111           .failed())
112     return failure();
113 
114   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
115   auto reshapeOutputType = RankedTensorType::get(
116       ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
117 
118   // Verify the rank agrees with the output type if the output type is ranked.
119   if (outputType) {
120     if (outputType.getShape().size() != reshapeOutputShape.size() ||
121         outputType.getShape().size() != higherRankShape.size())
122       return failure();
123   }
124 
125   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
126       loc, reshapeOutputType, lowerTensorValue,
127       rewriter.getI64ArrayAttr(reshapeOutputShape));
128 
129   if (input1Rank > input2Rank) {
130     outInput1 = higherTensorValue;
131     outInput2 = reshapeLower.getResult();
132   } else {
133     outInput1 = reshapeLower.getResult();
134     outInput2 = higherTensorValue;
135   }
136 
137   return success();
138 }
139 
140 namespace {
141 template <typename OpTy>
142 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
143   using OpRewritePattern<OpTy>::OpRewritePattern;
144 
145   LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
146                                 PatternRewriter &rewriter) const override {
147 
148     Value input1 = tosaBinaryOp.input1();
149     Value input2 = tosaBinaryOp.input2();
150     Value output = tosaBinaryOp.getResult();
151 
152     auto outputType = output.getType().dyn_cast<RankedTensorType>();
153     if (!outputType)
154       return failure();
155 
156     Value outInput1, outInput2;
157     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
158                              input1, input2, outInput1, outInput2)
159             .failed())
160       return failure();
161 
162     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
163                                       outInput2);
164 
165     return success();
166   }
167 };
168 
169 // The MulOp has an extra parameter 'shift' not present in other elementwise
170 // binary ops, that necessitates special handling of its builder.
171 template <>
172 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
173   using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
174 
175   LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
176                                 PatternRewriter &rewriter) const override {
177 
178     Value input1 = tosaBinaryOp.input1();
179     Value input2 = tosaBinaryOp.input2();
180     int32_t shift = tosaBinaryOp.shift();
181     Value output = tosaBinaryOp.getResult();
182     auto outputType = output.getType().dyn_cast<RankedTensorType>();
183     if (!outputType)
184       return failure();
185 
186     Value outInput1, outInput2;
187     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
188                              input1, input2, outInput1, outInput2)
189             .failed())
190       return failure();
191 
192     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
193                                              outInput1, outInput2, shift);
194 
195     return success();
196   }
197 };
198 
199 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
200 // other elementwise binary ops, that necessitates special handling of its
201 // builder.
202 template <>
203 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
204     : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
205   using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
206 
207   LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
208                                 PatternRewriter &rewriter) const override {
209 
210     Value input1 = tosaBinaryOp.input1();
211     Value input2 = tosaBinaryOp.input2();
212     int32_t round = tosaBinaryOp.round();
213     Value output = tosaBinaryOp.getResult();
214     auto outputType = output.getType().dyn_cast<RankedTensorType>();
215     if (!outputType)
216       return failure();
217 
218     Value outInput1, outInput2;
219     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
220                              input1, input2, outInput1, outInput2)
221             .failed())
222       return failure();
223 
224     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
225         tosaBinaryOp, outputType, outInput1, outInput2, round);
226 
227     return success();
228   }
229 };
230 } // namespace
231 
232 namespace {
233 /// Pass that enables broadcast by making all input arrays have the same
234 /// number of dimensions. Insert RESHAPE operations to lower rank operand
235 struct TosaMakeBroadcastable
236     : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
237 public:
238   void runOnFunction() override {
239     auto func = getFunction();
240     RewritePatternSet patterns(func.getContext());
241     MLIRContext *ctx = func.getContext();
242     // Add the generated patterns to the list.
243     patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
244     patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
245     patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
246     patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
247     patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
248     patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
249     patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx);
250     patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
251     patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
252     patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
253     patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
254     patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
255     patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
256     patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
257     patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
258     patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
259     patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
260     patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
261     patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
262     (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
263   }
264 };
265 } // namespace
266 
/// Factory for the TOSA make-broadcastable pass; the caller owns the result.
std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
  return std::make_unique<TosaMakeBroadcastable>();
}
270