//===- TosaMakeBroadcastable.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Insert reshapes on a binary op's inputs where needed so that the operand
// ranks match.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::tosa;

/// There are two potential ways of implementing broadcast:
/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
/// This pass currently implements b (NumPy style).

/// In this pass, we insert RESHAPE operators to increase the rank of the
/// lower-rank operand as a first step in the broadcasting process. The TOSA
/// operators that support broadcast require the ranks of their operands to be
/// equal.

// Examples:
// If lower=[c], higher=[a, b, c], [c] is reshaped into [1, 1, c].
// If lower=[b, c], higher=[a, b, c], [b, c] is reshaped into [1, b, c].
// If lower=[a], higher=[a, a], [a] is reshaped into [1, a].
// If lower=[a], higher=[a, b, a], [a] is reshaped into [1, 1, a].
// If lower=[], higher=[a, b, c], [] is reshaped into [1, 1, 1].
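//
// For illustration, a minimal before/after sketch in generic-form TOSA IR
// (the exact printed syntax may vary across MLIR versions):
//
//   %0 = "tosa.add"(%arg0, %arg1)
//          : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xf32>
//
// becomes
//
//   %r = "tosa.reshape"(%arg1) {new_shape = [1, 3]}
//          : (tensor<3xf32>) -> tensor<1x3xf32>
//   %0 = "tosa.add"(%arg0, %r)
//          : (tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32>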

static LogicalResult
computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                     ArrayRef<int64_t> lowerRankShape,
                     SmallVectorImpl<int64_t> &reshapeOutputShape) {
  // Initialize new shapes with [1] * higherRank.
  int64_t higherRank = higherRankShape.size();
  int64_t lowerRank = lowerRankShape.size();

  reshapeOutputShape.assign(higherRank, 1);

  int64_t higherRankDim;
  int64_t lowerRankDim;

  // Walk both shapes from the trailing dimension; dimensions are broadcast
  // compatible when they are equal or when either is 1.
  for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0;
       i--, j--) {
    higherRankDim = higherRankShape[i];
    lowerRankDim = lowerRankShape[j];

    if (lowerRankDim == 1 && higherRankDim > 1)
      reshapeOutputShape[i] = 1;
    else if ((lowerRankDim > 1 && higherRankDim == 1) ||
             (lowerRankDim == higherRankDim))
      reshapeOutputShape[i] = lowerRankDim;
    else if (higherRankDim != lowerRankDim)
      return failure();
  }
  return success();
}
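
// For illustration, a worked trace of computeReshapeOutput on hypothetical
// shapes higher=[3, 1, 5], lower=[4, 5]: the trailing dims 5/5 match, so
// position 2 keeps 5; next, lower 4 against higher 1 keeps 4; the remaining
// leading position stays at its initial 1, yielding [1, 4, 5]. A true
// mismatch such as higher=[3, 4] vs. lower=[5] returns failure().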

/// Common code to create the reshape op where necessary to make the ranks of
/// the two operands equal. Returns the updated input1 and input2 through
/// outInput1 and outInput2. The caller is expected to use these to rewrite
/// the original operator with the RESHAPE now in the graph.
static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
                                          Location loc,
                                          RankedTensorType outputType,
                                          Value input1, Value input2,
                                          Value &outInput1, Value &outInput2) {
  auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
  auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();

  if (!input1Ty || !input2Ty)
    return failure();

  int64_t input1Rank = input1Ty.getRank();
  int64_t input2Rank = input2Ty.getRank();

  Value higherTensorValue, lowerTensorValue;
  // Cannot rewrite as it's already correct.
  if (input1Rank == input2Rank)
    return failure();

  if (input1Rank > input2Rank) {
    higherTensorValue = input1;
    lowerTensorValue = input2;
  } else {
    higherTensorValue = input2;
    lowerTensorValue = input1;
  }

  ArrayRef<int64_t> higherRankShape =
      higherTensorValue.getType().cast<RankedTensorType>().getShape();
  ArrayRef<int64_t> lowerRankShape =
      lowerTensorValue.getType().cast<RankedTensorType>().getShape();

  SmallVector<int64_t, 4> reshapeOutputShape;

  if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape)
          .failed())
    return failure();

  auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
  auto reshapeOutputType = RankedTensorType::get(
      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());

  // Verify the rank agrees with the output type if the output type is ranked.
  if (outputType) {
    if (outputType.getShape().size() != reshapeOutputShape.size() ||
        outputType.getShape().size() != higherRankShape.size())
      return failure();
  }

  auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
      loc, reshapeOutputType, lowerTensorValue,
      rewriter.getI64ArrayAttr(reshapeOutputShape));

  if (input1Rank > input2Rank) {
    outInput1 = higherTensorValue;
    outInput2 = reshapeLower.getResult();
  } else {
    outInput1 = reshapeLower.getResult();
    outInput2 = higherTensorValue;
  }

  return success();
}
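
// For illustration, given hypothetical operands input1 : tensor<2x3x4xf32> and
// input2 : tensor<4xf32>, reshapeLowerToHigher emits a tosa.reshape of input2
// to tensor<1x1x4xf32> and returns outInput1 = input1 and outInput2 = the
// reshape result; the caller then rebuilds the binary op on those values.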

namespace {
template <typename OpTy>
struct ConvertTosaOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
                                PatternRewriter &rewriter) const override {
    Value input1 = tosaBinaryOp.getInput1();
    Value input2 = tosaBinaryOp.getInput2();
    Value output = tosaBinaryOp.getResult();

    auto outputType = output.getType().dyn_cast<RankedTensorType>();
    if (!outputType)
      return failure();

    Value outInput1, outInput2;
    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2, outInput1, outInput2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
                                      outInput2);

    return success();
  }
};

// The MulOp has an extra parameter 'shift' not present in other elementwise
// binary ops, which necessitates special handling of its builder.
template <>
struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
  using OpRewritePattern<tosa::MulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
                                PatternRewriter &rewriter) const override {
    Value input1 = tosaBinaryOp.getInput1();
    Value input2 = tosaBinaryOp.getInput2();
    int32_t shift = tosaBinaryOp.getShift();
    Value output = tosaBinaryOp.getResult();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();
    if (!outputType)
      return failure();

    Value outInput1, outInput2;
    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2, outInput1, outInput2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
                                             outInput1, outInput2, shift);

    return success();
  }
};

// The ArithmeticRightShiftOp has an extra parameter 'round' not present in
// other elementwise binary ops, which necessitates special handling of its
// builder.
template <>
struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
    : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
  using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
                                PatternRewriter &rewriter) const override {
    Value input1 = tosaBinaryOp.getInput1();
    Value input2 = tosaBinaryOp.getInput2();
    int32_t round = tosaBinaryOp.getRound();
    Value output = tosaBinaryOp.getResult();
    auto outputType = output.getType().dyn_cast<RankedTensorType>();
    if (!outputType)
      return failure();

    Value outInput1, outInput2;
    if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
                             input1, input2, outInput1, outInput2)
            .failed())
      return failure();

    rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
        tosaBinaryOp, outputType, outInput1, outInput2, round);

    return success();
  }
};
} // namespace

namespace {
/// Pass that enables broadcast by making all input arrays have the same
/// number of dimensions. Inserts RESHAPE operations on the lower-rank operand.
struct TosaMakeBroadcastable
    : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
public:
  void runOnOperation() override {
    auto func = getOperation();
    RewritePatternSet patterns(func.getContext());
    MLIRContext *ctx = func.getContext();
    // Add the generated patterns to the list.
    patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
    patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
  }
};
} // namespace

std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
  return std::make_unique<TosaMakeBroadcastable>();
}
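
// For illustration, a minimal sketch of running this pass from a host tool
// (assuming an MLIR build where functions are mlir::func::FuncOp; `ctx` and
// `module` are hypothetical names for the example):
//
//   mlir::PassManager pm(ctx);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::tosa::createTosaMakeBroadcastablePass());
//   if (mlir::failed(pm.run(module)))
//     llvm::errs() << "TosaMakeBroadcastable failed\n";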