1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Tensor/Utils/Utils.h"
20 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
21 #include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
22 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 
28 #include <numeric>
29 
30 using namespace mlir;
31 using namespace mlir::tosa;
32 
33 template <typename T>
34 static arith::ConstantOp
createConstFromIntAttribute(Operation * op,const std::string & attrName,Type requiredAttrType,OpBuilder & rewriter)35 createConstFromIntAttribute(Operation *op, const std::string &attrName,
36                             Type requiredAttrType, OpBuilder &rewriter) {
37   auto castedN = static_cast<T>(
38       op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
39   return rewriter.create<arith::ConstantOp>(
40       op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
41 }
42 
43 static Value
createLinalgBodyCalculationForElementwiseOp(Operation * op,ValueRange args,ArrayRef<Type> resultTypes,PatternRewriter & rewriter)44 createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
45                                             ArrayRef<Type> resultTypes,
46                                             PatternRewriter &rewriter) {
47   Location loc = op->getLoc();
48   auto elementTy =
49       op->getOperand(0).getType().cast<ShapedType>().getElementType();
50 
51   // tosa::AbsOp
52   if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
53     return rewriter.create<math::AbsOp>(loc, resultTypes, args);
54 
55   if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
56     auto zero = rewriter.create<arith::ConstantOp>(
57         loc, rewriter.getZeroAttr(elementTy));
58     auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
59                                               args[0], zero);
60     auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
61     return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
62   }
63 
64   // tosa::AddOp
65   if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
66     return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
67 
68   if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
69     return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
70 
71   // tosa::SubOp
72   if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
73     return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
74 
75   if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
76     return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
77 
78   // tosa::MulOp
79   if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
80     if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
81       (void)rewriter.notifyMatchFailure(op,
82                                         "Cannot have shift value for float");
83       return nullptr;
84     }
85     return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
86   }
87 
88   // tosa::DivOp
89   if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
90     return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
91 
92   // tosa::ReciprocalOp
93   if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
94     auto one =
95         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
96     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
97   }
98 
99   if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
100     Value a = args[0];
101     Value b = args[1];
102     auto shift =
103         op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
104     if (shift > 0) {
105       auto shiftConst =
106           rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
107       if (!a.getType().isInteger(32))
108         a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
109 
110       if (!b.getType().isInteger(32))
111         b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
112 
113       auto result = rewriter.create<tosa::ApplyScaleOp>(
114           loc, rewriter.getI32Type(), a, b, shiftConst,
115           rewriter.getBoolAttr(false));
116 
117       if (elementTy.isInteger(32))
118         return result;
119 
120       return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
121     }
122 
123     int aWidth = a.getType().getIntOrFloatBitWidth();
124     int bWidth = b.getType().getIntOrFloatBitWidth();
125     int cWidth = resultTypes[0].getIntOrFloatBitWidth();
126 
127     if (aWidth < cWidth)
128       a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
129     if (bWidth < cWidth)
130       b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
131 
132     return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
133   }
134 
135   // tosa::NegateOp
136   if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
137     return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
138 
139   if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
140       !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
141     auto constant =
142         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
143     return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
144   }
145 
146   if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
147       cast<tosa::NegateOp>(op).getQuantizationInfo()) {
148     auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
149     int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
150     int64_t inZp = quantizationInfo.value().getInputZp();
151     int64_t outZp = quantizationInfo.value().getOutputZp();
152 
153     // Compute the maximum value that can occur in the intermediate buffer.
154     int64_t zpAdd = inZp + outZp;
155     int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
156                        std::abs(zpAdd) + 1;
157 
158     // Convert that maximum value into the maximum bitwidth needed to represent
159     // it. We assume 48-bit numbers may be supported further in the pipeline.
160     int intermediateBitWidth = 64;
161     if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
162       intermediateBitWidth = 16;
163     } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
164       intermediateBitWidth = 32;
165     } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
166       intermediateBitWidth = 48;
167     }
168 
169     Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
170     Value zpAddValue = rewriter.create<arith::ConstantOp>(
171         loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
172 
173     // The negation can be applied by doing:
174     //  outputValue = inZp + outZp - inputValue
175     auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
176     auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
177 
178     // Clamp to the negation range.
179     auto min = rewriter.create<arith::ConstantIntOp>(
180         loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
181         intermediateType);
182     auto max = rewriter.create<arith::ConstantIntOp>(
183         loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
184         intermediateType);
185     auto clamp = clampHelper<arith::CmpIOp>(
186         loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);
187 
188     // Truncate to the final value.
189     return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
190   }
191 
192   // tosa::BitwiseAndOp
193   if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
194     return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
195 
196   // tosa::BitwiseOrOp
197   if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
198     return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
199 
200   // tosa::BitwiseNotOp
201   if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
202     auto allOnesAttr = rewriter.getIntegerAttr(
203         elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
204     auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
205     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
206   }
207 
208   // tosa::BitwiseXOrOp
209   if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
210     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
211 
212   // tosa::LogicalLeftShiftOp
213   if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
214     return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
215 
216   // tosa::LogicalRightShiftOp
217   if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
218     return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
219 
220   // tosa::ArithmeticRightShiftOp
221   if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
222     auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
223     auto round = op->getAttr("round").cast<BoolAttr>().getValue();
224     if (!round) {
225       return result;
226     }
227 
228     Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
229     auto one =
230         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
231     auto zero =
232         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
233     auto i1one =
234         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
235 
236     // Checking that input2 != 0
237     auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
238         loc, arith::CmpIPredicate::sgt, args[1], zero);
239 
240     // Checking for the last bit of input1 to be 1
241     auto subtract =
242         rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
243     auto shifted =
244         rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
245             ->getResults();
246     auto truncated =
247         rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, mlir::None);
248     auto isInputOdd =
249         rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
250 
251     auto shouldRound = rewriter.create<arith::AndIOp>(
252         loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
253     auto extended =
254         rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
255     return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
256   }
257 
258   // tosa::ClzOp
259   if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
260     return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
261   }
262 
263   // tosa::LogicalAnd
264   if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
265     return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
266 
267   // tosa::LogicalNot
268   if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
269     auto one = rewriter.create<arith::ConstantOp>(
270         loc, rewriter.getIntegerAttr(elementTy, 1));
271     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
272   }
273 
274   // tosa::LogicalOr
275   if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
276     return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
277 
278   // tosa::LogicalXor
279   if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
280     return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
281 
282   // tosa::PowOp
283   if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
284     return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
285 
286   // tosa::RsqrtOp
287   if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
288     return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
289 
290   // tosa::LogOp
291   if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
292     return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
293 
294   // tosa::ExpOp
295   if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
296     return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
297 
298   // tosa::TanhOp
299   if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
300     return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
301 
302   // tosa::GreaterOp
303   if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
304     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
305                                           args[0], args[1]);
306 
307   if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
308     return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
309                                           args[0], args[1]);
310 
311   // tosa::GreaterEqualOp
312   if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
313     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
314                                           args[0], args[1]);
315 
316   if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
317     return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
318                                           args[0], args[1]);
319 
320   // tosa::EqualOp
321   if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
322     return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
323                                           args[0], args[1]);
324 
325   if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
326     return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
327                                           args[0], args[1]);
328 
329   // tosa::SelectOp
330   if (isa<tosa::SelectOp>(op)) {
331     elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
332     if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
333       return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
334   }
335 
336   // tosa::MaximumOp
337   if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
338     auto predicate = rewriter.create<arith::CmpFOp>(
339         loc, arith::CmpFPredicate::OGT, args[0], args[1]);
340     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
341   }
342 
343   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
344     auto predicate = rewriter.create<arith::CmpIOp>(
345         loc, arith::CmpIPredicate::sgt, args[0], args[1]);
346     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
347   }
348 
349   // tosa::MinimumOp
350   if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
351     auto predicate = rewriter.create<arith::CmpFOp>(
352         loc, arith::CmpFPredicate::OLT, args[0], args[1]);
353     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
354   }
355 
356   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
357     auto predicate = rewriter.create<arith::CmpIOp>(
358         loc, arith::CmpIPredicate::slt, args[0], args[1]);
359     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
360   }
361 
362   // tosa::CeilOp
363   if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
364     return rewriter.create<math::CeilOp>(loc, resultTypes, args);
365 
366   // tosa::FloorOp
367   if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
368     return rewriter.create<math::FloorOp>(loc, resultTypes, args);
369 
370   // tosa::ClampOp
371   if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
372     bool losesInfo = false;
373     APFloat min_apf = op->getAttr("min_fp").cast<FloatAttr>().getValue();
374     APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
375     min_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
376                     APFloat::rmNearestTiesToEven, &losesInfo);
377     max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
378                     APFloat::rmNearestTiesToEven, &losesInfo);
379     auto min = rewriter.create<arith::ConstantOp>(
380         loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf));
381     auto max = rewriter.create<arith::ConstantOp>(
382         loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
383     return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
384                                       arith::CmpFPredicate::OLT, rewriter);
385   }
386 
387   if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
388     auto intTy = elementTy.cast<IntegerType>();
389     int32_t min = static_cast<int32_t>(
390         op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
391     int32_t max = static_cast<int32_t>(
392         op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());
393 
394     if (intTy.isUnsignedInteger()) {
395       min = std::max<int32_t>(min, 0);
396       max = std::min<int32_t>(
397           max,
398           APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
399     } else {
400       min = std::max<int32_t>(
401           min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
402                    .getSExtValue());
403       max = std::min<int32_t>(
404           max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
405                    .getSExtValue());
406     }
407 
408     auto minVal = rewriter.create<arith::ConstantIntOp>(
409         loc, min, intTy.getIntOrFloatBitWidth());
410     auto maxVal = rewriter.create<arith::ConstantIntOp>(
411         loc, max, intTy.getIntOrFloatBitWidth());
412     return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
413                                       arith::CmpIPredicate::slt, rewriter);
414   }
415 
416   // tosa::ReluNOp
417   if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
418     auto zero =
419         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
420     bool losesInfo = false;
421     APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
422     max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
423                     APFloat::rmNearestTiesToEven, &losesInfo);
424     auto n = rewriter.create<arith::ConstantOp>(
425         loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
426     return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
427                                       arith::CmpFPredicate::OLT, rewriter);
428   }
429 
430   if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
431     auto zero =
432         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
433     auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
434                                                   rewriter);
435     return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
436                                       arith::CmpIPredicate::slt, rewriter);
437   }
438 
439   // tosa::SigmoidOp
440   if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
441     auto one =
442         rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
443     auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
444     auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
445     auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
446     return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
447   }
448 
449   // tosa::CastOp
450   if (isa<tosa::CastOp>(op)) {
451     Type srcTy = elementTy;
452     Type dstTy = resultTypes.front();
453     bool bitExtend =
454         srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
455 
456     if (srcTy == dstTy)
457       return args.front();
458 
459     if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
460       return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, mlir::None);
461 
462     if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
463       return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
464                                               mlir::None);
465 
466     // 1-bit integers need to be treated as signless.
467     if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
468       return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
469                                               mlir::None);
470 
471     if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
472       return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
473                                              mlir::None);
474 
475     // Unsigned integers need an unrealized cast so that they can be passed
476     // to UIToFP.
477     if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
478       auto unrealizedCast =
479           rewriter
480               .create<UnrealizedConversionCastOp>(
481                   loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
482                   args[0])
483               .getResult(0);
484       return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
485                                               unrealizedCast);
486     }
487 
488     // All other si-to-fp conversions should be handled by SIToFP.
489     if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
490       return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
491                                               mlir::None);
492 
493     // Casting to boolean, floats need to only be checked as not-equal to zero.
494     if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
495       Value zero = rewriter.create<arith::ConstantOp>(
496           loc, rewriter.getFloatAttr(srcTy, 0.0));
497       return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
498                                             args.front(), zero);
499     }
500 
501     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
502       auto zero = rewriter.create<arith::ConstantOp>(
503           loc, rewriter.getF32FloatAttr(0.0f));
504       auto half = rewriter.create<arith::ConstantOp>(
505           loc, rewriter.getF32FloatAttr(0.5f));
506 
507       auto intMin = rewriter.create<arith::ConstantOp>(
508           loc, rewriter.getF32FloatAttr(
509                    APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
510                        .getSExtValue()));
511 
512       auto intMax = rewriter.create<arith::ConstantOp>(
513           loc, rewriter.getF32FloatAttr(
514                    APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
515                        .getSExtValue()));
516 
517       auto added = rewriter.create<arith::AddFOp>(loc, args[0], half);
518       auto subbed = rewriter.create<arith::SubFOp>(loc, args[0], half);
519       auto negative = rewriter.create<arith::CmpFOp>(
520           loc, arith::CmpFPredicate::OLT, args[0], zero);
521       auto rounded =
522           rewriter.create<arith::SelectOp>(loc, negative, subbed, added);
523 
524       auto clamped = clampHelper<arith::CmpFOp>(
525           loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);
526 
527       return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
528     }
529 
530     // Casting to boolean, integers need to only be checked as not-equal to
531     // zero.
532     if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
533       Value zero = rewriter.create<arith::ConstantIntOp>(
534           loc, 0, srcTy.getIntOrFloatBitWidth());
535       return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
536                                             args.front(), zero);
537     }
538 
539     if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
540       return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
541                                              mlir::None);
542 
543     if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
544       auto intMin = rewriter.create<arith::ConstantIntOp>(
545           loc,
546           APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
547               .getSExtValue(),
548           srcTy.getIntOrFloatBitWidth());
549 
550       auto intMax = rewriter.create<arith::ConstantIntOp>(
551           loc,
552           APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
553               .getSExtValue(),
554           srcTy.getIntOrFloatBitWidth());
555 
556       auto clamped = clampHelper<arith::CmpIOp>(
557           loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
558       return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
559     }
560   }
561 
562   (void)rewriter.notifyMatchFailure(
563       op, "unhandled op for linalg body calculation for elementwise op");
564   return nullptr;
565 }
566 
// Shared lowering driver for all TOSA elementwise ops: builds a
// linalg.generic over the operands (reshaping broadcast operands and giving
// them reduced indexing maps), then delegates the scalar body to
// createLinalgBodyCalculationForElementwiseOp. Fails the match when the
// result is not shaped or the body computation is unsupported.
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  auto loc = operation->getLoc();

  assert(operation->getNumResults() == 1 &&
         "All TOSA elementwise ops should only return a single result.");

  auto results = operation->getResults();
  auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();

  if (!resultTy)
    return rewriter.notifyMatchFailure(operation,
                                       "All results must be a shaped type");

  unsigned rank = resultTy.getRank();

  // Construct the indexing maps needed for linalg.generic ops.
  SmallVector<Type> bodyArgTypes;

  for (Value in : operation->getOperands())
    bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));

  SmallVector<Type> opResultTypes;
  SmallVector<Value> initTensors;

  // Collect one tensor.dim value per dynamic result dimension, taking the
  // first operand that provides it (slots for static dims stay null).
  SmallVector<Value> dynDims;
  dynDims.resize(results.front().getType().cast<ShapedType>().getRank());

  for (auto arg : operation->getOperands()) {
    auto operandTy = arg.getType().cast<ShapedType>();
    for (int i = 0; i < operandTy.getRank(); i++) {
      if (operandTy.isDynamicDim(i) && !dynDims[i])
        dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
    }
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);

  // Create an init tensor (the linalg output operand) for each result.
  for (auto result : results) {
    auto resultTy = result.getType().template cast<ShapedType>();
    initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
    opResultTypes.push_back(result.getType());
  }

  auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
      initTensors, [](Value v) { return getElementTypeOrSelf(v); }));

  SmallVector<Value, 2> operands;
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());

  // Input indexing maps may be broadcasted.
  for (Value operand : operation->getOperands()) {
    ShapedType type = operand.getType().cast<ShapedType>();

    // Operands whose shape already matches the result use the identity map.
    if (type.getShape() == resultTy.getShape()) {
      operands.push_back(operand);
      indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
      continue;
    }

    // For broadcast operands, keep only the dimensions that match the result
    // shape; the dropped (size-1) dimensions are removed via a tosa.reshape
    // and omitted from the operand's affine map.
    SmallVector<int64_t, 5> newShape;
    SmallVector<AffineExpr, 4> affineExprs;
    newShape.reserve(type.getRank());
    for (const auto &it : llvm::enumerate(type.getShape())) {
      if (it.value() == resultTy.getDimSize(it.index())) {
        newShape.push_back(it.value());
        affineExprs.push_back(
            mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
      }
    }

    if (newShape.size() != rank) {
      operand = rewriter.create<tosa::ReshapeOp>(
          loc, RankedTensorType::get(newShape, type.getElementType()), operand,
          rewriter.getI64ArrayAttr(newShape));
    }

    operands.push_back(operand);
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
        rewriter.getContext()));
  }

  // Results always use the identity map.
  indexingMaps.append(operation->getNumResults(),
                      rewriter.getMultiDimIdentityMap(rank));

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, opResultTypes, operands, initTensors, indexingMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        // NOTE(review): the body calculation is built with the outer
        // `rewriter` rather than `nestedBuilder` — it appears to rely on the
        // rewriter's insertion point being inside the region; confirm before
        // changing.
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            bodyResultTypes, rewriter);
        if (!opResult) {
          didEncounterError = true;
          return;
        }
        nestedBuilder.create<linalg::YieldOp>(loc, opResult);
      });

  if (didEncounterError)
    return failure();

  rewriter.replaceOp(operation, linalgOp->getResults());
  return success();
}
677 
678 // Returns the constant initial value for a given reduction operation. The
679 // attribute type varies depending on the element type required.
createInitialValueForReduceOp(Operation * op,Type elementTy,PatternRewriter & rewriter)680 static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
681                                                PatternRewriter &rewriter) {
682   if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
683     return rewriter.getFloatAttr(elementTy, 0.0);
684 
685   if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
686     return rewriter.getIntegerAttr(elementTy, 0);
687 
688   if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
689     return rewriter.getFloatAttr(elementTy, 1.0);
690 
691   if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
692     return rewriter.getIntegerAttr(elementTy, 1);
693 
694   if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
695     return rewriter.getFloatAttr(
696         elementTy, APFloat::getLargest(
697                        elementTy.cast<FloatType>().getFloatSemantics(), false));
698 
699   if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
700     return rewriter.getIntegerAttr(
701         elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
702 
703   if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
704     return rewriter.getFloatAttr(
705         elementTy, APFloat::getLargest(
706                        elementTy.cast<FloatType>().getFloatSemantics(), true));
707 
708   if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
709     return rewriter.getIntegerAttr(
710         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
711 
712   if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
713     return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
714 
715   if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
716     return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
717 
718   if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
719     return rewriter.getFloatAttr(
720         elementTy, APFloat::getLargest(
721                        elementTy.cast<FloatType>().getFloatSemantics(), true));
722 
723   if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
724     return rewriter.getIntegerAttr(
725         elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
726 
727   return {};
728 }
729 
730 // Creates the body calculation for a reduction. The operations vary depending
731 // on the input type.
createLinalgBodyCalculationForReduceOp(Operation * op,ValueRange args,Type elementTy,PatternRewriter & rewriter)732 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
733                                                     ValueRange args,
734                                                     Type elementTy,
735                                                     PatternRewriter &rewriter) {
736   Location loc = op->getLoc();
737   if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
738     return rewriter.create<arith::AddFOp>(loc, args);
739   }
740 
741   if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
742     return rewriter.create<arith::AddIOp>(loc, args);
743   }
744 
745   if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
746     return rewriter.create<arith::MulFOp>(loc, args);
747   }
748 
749   if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
750     return rewriter.create<arith::MulIOp>(loc, args);
751   }
752 
753   if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
754     auto predicate = rewriter.create<arith::CmpFOp>(
755         loc, arith::CmpFPredicate::OLT, args[0], args[1]);
756     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
757   }
758 
759   if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
760     auto predicate = rewriter.create<arith::CmpIOp>(
761         loc, arith::CmpIPredicate::slt, args[0], args[1]);
762     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
763   }
764 
765   if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
766     auto predicate = rewriter.create<arith::CmpFOp>(
767         loc, arith::CmpFPredicate::OGT, args[0], args[1]);
768     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
769   }
770 
771   if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
772     auto predicate = rewriter.create<arith::CmpIOp>(
773         loc, arith::CmpIPredicate::sgt, args[0], args[1]);
774     return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
775   }
776 
777   if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
778     return rewriter.create<arith::AndIOp>(loc, args);
779 
780   if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
781     return rewriter.create<arith::OrIOp>(loc, args);
782 
783   return {};
784 }
785 
786 // Performs the match and rewrite for reduction operations. This includes
787 // declaring a correctly sized initial value, and the linalg.generic operation
788 // that reduces across the specified axis.
reduceMatchAndRewriteHelper(Operation * op,uint64_t axis,PatternRewriter & rewriter)789 static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
790                                                  PatternRewriter &rewriter) {
791   auto loc = op->getLoc();
792   auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
793   auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
794   auto elementTy = resultTy.getElementType();
795   Value input = op->getOperand(0);
796 
797   llvm::SmallVector<int64_t> reduceShape;
798   SmallVector<Value> dynDims;
799   for (unsigned i = 0; i < inputTy.getRank(); i++) {
800     if (axis != i) {
801       reduceShape.push_back(inputTy.getDimSize(i));
802       if (inputTy.isDynamicDim(i))
803         dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
804     }
805   }
806 
807   Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());
808 
809   // First fill the output buffer with the init value.
810   auto initTensor = rewriter
811                         .create<linalg::InitTensorOp>(loc, dynDims, reduceShape,
812                                                       resultTy.getElementType())
813                         .result();
814 
815   auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
816   if (!fillValueAttr)
817     return rewriter.notifyMatchFailure(
818         op, "No initial value found for reduction operation");
819 
820   auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
821   auto filledTensor = rewriter
822                           .create<linalg::FillOp>(loc, ValueRange{fillValue},
823                                                   ValueRange{initTensor})
824                           .result();
825 
826   SmallVector<AffineExpr, 2> srcExprs;
827   SmallVector<AffineExpr, 2> dstExprs;
828   SmallVector<StringRef, 4> iteratorTypes;
829   for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
830     srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
831 
832     iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
833                                       : getParallelIteratorTypeName());
834     if (axis != i)
835       dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
836   }
837 
838   bool didEncounterError = false;
839   auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
840   auto linalgOp = rewriter.create<linalg::GenericOp>(
841       loc, reduceTy, input, filledTensor, maps, iteratorTypes,
842       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
843         auto result = createLinalgBodyCalculationForReduceOp(
844             op, blockArgs, elementTy, rewriter);
845         if (result)
846           didEncounterError = true;
847 
848         nestedBuilder.create<linalg::YieldOp>(loc, result);
849       });
850 
851   if (!didEncounterError)
852     return failure();
853 
854   rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
855                                                linalgOp.getResults());
856   return success();
857 }
858 
// Computes an intermediate shape that both `lhsShape` and `rhsShape` can be
// collapsed into: walks the two shapes in lockstep, absorbing adjacent
// dimensions on whichever side has the smaller running product until the
// products agree, and records each agreed product as one intermediate dim.
// Returns false when the shapes cannot be reconciled (leftover dims != 1).
static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
                                  ArrayRef<int64_t> rhsShape,
                                  SmallVector<int64_t> &intermediateShape,
                                  bool isDynamic) {
  if (isDynamic) {
    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
    intermediateShape = {-1};
    return true;
  }

  // A rank-0 operand on either side reconciles through the empty shape.
  if (lhsShape.empty() || rhsShape.empty()) {
    intermediateShape = {};
    return true;
  }

  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    // Grow the smaller running product by folding in the next dimension
    // until both sides match.
    // NOTE(review): the index is incremented before the read, so
    // lhsShape[currLhsDim] / rhsShape[currRhsDim] can index one past the end
    // when the products never converge in-bounds; presumably callers only
    // pass shapes with equal element counts -- verify.
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        lhsSize *= lhsShape[currLhsDim];
      } else {
        currRhsDim++;
        rhsSize *= rhsShape[currRhsDim];
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // If the iterators didn't reach the end and their leftover dimensions are not
  // equal to 1 an intermediate shape was not found.
  while (currLhsDim < lhsShape.size()) {
    if (lhsShape[currLhsDim++] != 1) {
      return false;
    }
  }

  while (currRhsDim < rhsShape.size()) {
    if (rhsShape[currRhsDim++] != 1) {
      return false;
    }
  }

  return true;
}
911 
// Builds the reassociation map that collapses `srcShape` into `dstShape`:
// entry i of `reassociationMap` lists the source dims folded into destination
// dim i. Returns false when the shapes are incompatible (leftover dims on
// either side).
static bool createReassociationMapsForCollapse(
    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
    ArrayRef<int64_t> dstShape,
    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {

  // If the shape is dynamic, create a map for collapsing into one dimension.
  if (isDynamic) {
    SmallVector<AffineExpr, 2> exprs;
    for (int i = 0, s = srcShape.size(); i < s; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    reassociationMap = {exprs};
    return true;
  }

  // Collapsing to rank-0 requires no reassociation groups.
  if (dstShape.empty()) {
    reassociationMap = {};
    return true;
  }

  reassociationMap.resize(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    // Absorb source dims into the current destination group until the running
    // product reaches the destination size.
    // NOTE(review): srcShape[currSrcDim] is read after the post-increment, so
    // this can index one past the end when the product overshoots within
    // bounds; presumably callers pre-validate via findIntermediateShape --
    // verify.
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          rewriter.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              rewriter.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If either iterator failed to reach the end there are leftover dimensions,
  // which implies a mismatch in shape.
  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
}
960 
961 namespace {
962 
// Rewrites a TOSA elementwise op of type SrcOp by delegating to the shared
// elementwise lowering helper, which emits the equivalent linalg.generic.
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};
973 
974 class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
975 public:
976   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
977 
978   LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const979   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
980                   ConversionPatternRewriter &rewriter) const final {
981     ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
982     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
983     bool isDynamic = !operandTy.hasStaticShape();
984 
985     if (isDynamic && resultTy.getRank() != 1) {
986       return rewriter.notifyMatchFailure(
987           reshape, "Cannot collapse dynamic dims to more than one dimension");
988     }
989 
990     if (operandTy == resultTy) {
991       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
992       return success();
993     }
994 
995     SmallVector<ReassociationExprs, 4> reassociationMap;
996     if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
997                                             resultTy.getShape(),
998                                             reassociationMap, isDynamic)) {
999       return rewriter.notifyMatchFailure(
1000           reshape,
1001           "tosa.reshape Attempting to collapse into an incompatible shape");
1002     }
1003 
1004     SmallVector<int64_t> intermediateShape;
1005     if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
1006                                intermediateShape, isDynamic)) {
1007       return rewriter.notifyMatchFailure(
1008           reshape, "tosa.reshape Cannot collapse into given shape");
1009     }
1010 
1011     rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1012         reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
1013     return success();
1014   }
1015 };
1016 
1017 class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
1018 public:
1019   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
1020 
1021   LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1022   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
1023                   ConversionPatternRewriter &rewriter) const final {
1024     ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
1025     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
1026     bool isDynamic = !operandTy.hasStaticShape();
1027 
1028     if (operandTy == resultTy) {
1029       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
1030       return success();
1031     }
1032 
1033     if (isDynamic && operandTy.getRank() != 1) {
1034       return rewriter.notifyMatchFailure(
1035           reshape, "Cannot expand dynamic dims from more than one dimension");
1036     }
1037 
1038     SmallVector<ReassociationExprs, 4> reassociationMap;
1039     if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
1040                                             operandTy.getShape(),
1041                                             reassociationMap, isDynamic)) {
1042       return rewriter.notifyMatchFailure(
1043           reshape,
1044           "tosa.reshape Attempting to expand into an incompatible shape");
1045     }
1046 
1047     SmallVector<int64_t> intermediateShape;
1048     if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
1049                                intermediateShape, isDynamic) ||
1050         intermediateShape != operandTy.getShape()) {
1051       return rewriter.notifyMatchFailure(
1052           reshape, "tosa.reshape Cannot expand into given shape");
1053     }
1054     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1055         reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
1056     return success();
1057   }
1058 };
1059 
1060 class ReshapeConverterCollapseExpand
1061     : public OpConversionPattern<tosa::ReshapeOp> {
1062 public:
1063   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
1064 
1065   LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1066   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
1067                   ConversionPatternRewriter &rewriter) const final {
1068     ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
1069     ShapedType resultTy = reshape.getType().template cast<ShapedType>();
1070     bool isDynamic = !operandTy.hasStaticShape();
1071 
1072     if (operandTy == resultTy) {
1073       rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
1074       return success();
1075     }
1076 
1077     SmallVector<int64_t> intermediateShape;
1078     if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
1079                                intermediateShape, isDynamic)) {
1080       return rewriter.notifyMatchFailure(
1081           reshape, "tosa.reshape Cannot identify an intermediate shape between "
1082                    "the given two shapes");
1083     }
1084 
1085     Value collapse = rewriter.create<tosa::ReshapeOp>(
1086         reshape.getLoc(),
1087         RankedTensorType::get(intermediateShape,
1088                               reshape.getType().getElementType()),
1089         adaptor.getInput1());
1090     Value expand =
1091         rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
1092     rewriter.replaceOp(reshape, expand);
1093 
1094     return success();
1095   }
1096 };
1097 
1098 class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1099 public:
1100   using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
1101 
matchAndRewrite(tosa::TransposeOp op,PatternRewriter & rewriter) const1102   LogicalResult matchAndRewrite(tosa::TransposeOp op,
1103                                 PatternRewriter &rewriter) const final {
1104     DenseIntElementsAttr perms;
1105     if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
1106       return failure();
1107     }
1108 
1109     auto loc = op.getLoc();
1110     auto input = op->getOperand(0);
1111     auto resultTy = op.getType().cast<ShapedType>();
1112 
1113     SmallVector<Value> dynDims;
1114     dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
1115 
1116     SmallVector<AffineExpr, 2> inputExprs;
1117     inputExprs.resize(resultTy.getRank());
1118     auto operandTy = input.getType().cast<ShapedType>();
1119     for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
1120       auto index = permutation.index();
1121       auto value = permutation.value().getZExtValue();
1122       if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
1123         dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
1124       }
1125       inputExprs[value] = rewriter.getAffineDimExpr(index);
1126     }
1127 
1128     SmallVector<Value> filteredDims = condenseValues(dynDims);
1129 
1130     auto initTensor = rewriter.create<linalg::InitTensorOp>(
1131         loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
1132 
1133     SmallVector<AffineMap, 2> affineMaps = {
1134         AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
1135                        rewriter.getContext()),
1136         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1137 
1138     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1139         op, resultTy, op.getInput1(), ValueRange{initTensor}, affineMaps,
1140         getNParallelLoopsAttrs(resultTy.getRank()),
1141         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1142           nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
1143         });
1144     return success();
1145   }
1146 };
1147 
// Lowers tosa.rescale to a linalg.generic that performs the integer rescale
// per element: subtract the input zero-point, apply the fixed-point scale via
// tosa.apply_scale, add the output zero-point, then clamp to the output
// type's range. Per-channel multiplier/shift values are materialized as
// constant tensor inputs of the generic op.
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = op.getInput().getType().cast<ShapedType>();
    auto outputTy = op.getOutput().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    // double_round is only defined together with scale32; reject the illegal
    // combination.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    // Only dynamic batch dims are supported by this helper; it returns the
    // dynamic-dim values (empty optional means unsupported dynamism).
    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.getMultiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.getShift(), shiftValues);

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double round only occurs if shift is greater than 31, check that this
    // is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      // Per-channel: a rank-1 constant indexed by the innermost dimension.
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the indexing maps needed for linalg.generic ops.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // Intermediate computation width: 32-bit normally, 48-bit when the
          // input type is wider than 32 bits. This is not optimal but should
          // be correct for now; consider computing a tighter bit depth later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          // Broadcast (scalar) or per-channel (block-arg) scale parameters.
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Widen sub-32-bit inputs to i32; unsigned values go through a
          // signless cast before the zero-extension.
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          // Remove the input zero-point before scaling.
          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampHelper<arith::CmpIOp>(
              nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
              nestedBuilder);

          // Truncate to the output width, casting back to unsigned if needed.
          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
1341 
1342 class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
1343 public:
1344   using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1345 
matchAndRewrite(tosa::ResizeOp op,PatternRewriter & rewriter) const1346   LogicalResult matchAndRewrite(tosa::ResizeOp op,
1347                                 PatternRewriter &rewriter) const final {
1348     Location loc = op.getLoc();
1349     auto input = op.getInput();
1350     auto inputTy = input.getType().cast<ShapedType>();
1351     auto resultTy = op.getType().cast<ShapedType>();
1352     auto resultElementTy = resultTy.getElementType();
1353 
1354     auto imageH = inputTy.getShape()[1];
1355     auto imageW = inputTy.getShape()[2];
1356 
1357     auto dynamicDimsOr =
1358         checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
1359     if (!dynamicDimsOr.has_value())
1360       return failure();
1361     SmallVector<Value> dynamicDims = dynamicDimsOr.value();
1362 
1363     if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
1364       return failure();
1365 
1366     auto initTensor = rewriter.create<linalg::InitTensorOp>(
1367         loc, dynamicDims, resultTy.getShape(), resultElementTy);
1368 
1369     SmallVector<AffineMap, 2> affineMaps = {
1370         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1371 
1372     auto genericOp = rewriter.create<linalg::GenericOp>(
1373         loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
1374         getNParallelLoopsAttrs(resultTy.getRank()));
1375     rewriter.replaceOp(op, genericOp.getResult(0));
1376 
1377     OpBuilder::InsertionGuard regionGuard(rewriter);
1378     rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
1379                          TypeRange({resultElementTy}), loc);
1380     Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
1381     Value y = rewriter.create<linalg::IndexOp>(loc, 1);
1382     Value x = rewriter.create<linalg::IndexOp>(loc, 2);
1383     Value channel = rewriter.create<linalg::IndexOp>(loc, 3);
1384 
1385     auto hwMin =
1386         rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
1387     auto hMax = rewriter.create<arith::ConstantOp>(
1388         loc, rewriter.getI32IntegerAttr(imageH - 1));
1389     auto wMax = rewriter.create<arith::ConstantOp>(
1390         loc, rewriter.getI32IntegerAttr(imageW - 1));
1391 
1392     Value inY =
1393         rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
1394     Value inX =
1395         rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);
1396 
1397     int32_t shift = op.getShift();
1398     bool floatingPointMode = shift == 0;
1399 
1400     Value yStride, xStride, yOffset, xOffset;
1401     if (floatingPointMode) {
1402       yStride = rewriter.create<arith::ConstantOp>(loc, op.getStrideFp()[0]);
1403       xStride = rewriter.create<arith::ConstantOp>(loc, op.getStrideFp()[1]);
1404       yOffset = rewriter.create<arith::ConstantOp>(loc, op.getOffsetFp()[0]);
1405       xOffset = rewriter.create<arith::ConstantOp>(loc, op.getOffsetFp()[1]);
1406     } else {
1407       SmallVector<int32_t> stride, offset;
1408       getValuesFromIntArrayAttribute(op.getStride(), stride);
1409       getValuesFromIntArrayAttribute(op.getOffset(), offset);
1410 
1411       yStride = rewriter.create<arith::ConstantOp>(
1412           loc, rewriter.getI32IntegerAttr(stride[0]));
1413       xStride = rewriter.create<arith::ConstantOp>(
1414           loc, rewriter.getI32IntegerAttr(stride[1]));
1415       yOffset = rewriter.create<arith::ConstantOp>(
1416           loc, rewriter.getI32IntegerAttr(offset[0]));
1417       xOffset = rewriter.create<arith::ConstantOp>(
1418           loc, rewriter.getI32IntegerAttr(offset[1]));
1419     }
1420 
    // Compute the integer index and partial offset.
1422     // x = x * stride + offset;
1423     // ix = floor(x)
1424     // dx = x - ix
1425     Value ix, iy, dx, dy;
1426     if (floatingPointMode) {
1427       Value y =
1428           rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
1429       Value x =
1430           rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);
1431 
1432       y = rewriter.create<arith::MulFOp>(loc, y, yStride);
1433       x = rewriter.create<arith::MulFOp>(loc, x, xStride);
1434 
1435       y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
1436       x = rewriter.create<arith::AddFOp>(loc, x, xOffset);
1437 
1438       iy = rewriter.create<math::FloorOp>(loc, y);
1439       ix = rewriter.create<math::FloorOp>(loc, x);
1440 
1441       dy = rewriter.create<arith::SubFOp>(loc, y, iy);
1442       dx = rewriter.create<arith::SubFOp>(loc, x, ix);
1443 
1444       iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
1445       ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
1446     } else {
1447       Value shiftVal = rewriter.create<arith::ConstantOp>(
1448           loc, rewriter.getI32IntegerAttr(shift));
1449 
1450       Value y = rewriter.create<arith::MulIOp>(loc, inY, yStride);
1451       Value x = rewriter.create<arith::MulIOp>(loc, inX, xStride);
1452 
1453       y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
1454       x = rewriter.create<arith::AddIOp>(loc, x, xOffset);
1455 
1456       iy = rewriter.create<arith::ShRSIOp>(loc, y, shiftVal);
1457       ix = rewriter.create<arith::ShRSIOp>(loc, x, shiftVal);
1458 
1459       Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal);
1460       Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal);
1461 
1462       dy = rewriter.create<arith::SubIOp>(loc, y, yTrunc);
1463       dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc);
1464     }
1465 
1466     if (op.getMode() == "NEAREST_NEIGHBOR") {
1467       Value yPred, xPred;
1468       // Round the index position towards the closest pixel location.
1469       if (floatingPointMode) {
1470         auto halfVal = rewriter.create<arith::ConstantOp>(
1471             loc, rewriter.getF32FloatAttr(0.5f));
1472         yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1473                                                dy, halfVal);
1474         xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1475                                                dx, halfVal);
1476       } else {
1477         auto halfVal = rewriter.create<arith::ConstantOp>(
1478             loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
1479         yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1480                                                dy, halfVal);
1481         xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1482                                                dx, halfVal);
1483       }
1484 
1485       auto zeroVal = rewriter.create<arith::ConstantOp>(
1486           loc, rewriter.getI32IntegerAttr(0));
1487       auto oneVal = rewriter.create<arith::ConstantOp>(
1488           loc, rewriter.getI32IntegerAttr(1));
1489 
1490       auto yOffset =
1491           rewriter.create<arith::SelectOp>(loc, yPred, oneVal, zeroVal);
1492       auto xOffset =
1493           rewriter.create<arith::SelectOp>(loc, xPred, oneVal, zeroVal);
1494 
1495       iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset);
1496       ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);
1497 
      // Clamp to be within the bounds of the input image.
1499 
1500       iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
1501                                       arith::CmpIPredicate::slt, rewriter);
1502       ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
1503                                       arith::CmpIPredicate::slt, rewriter);
1504 
1505       // Read the value from the input array.
1506       iy =
1507           rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), iy);
1508       ix =
1509           rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), ix);
1510 
1511       Value result = rewriter.create<tensor::ExtractOp>(
1512           loc, input, ValueRange{batch, iy, ix, channel});
1513 
1514       rewriter.create<linalg::YieldOp>(loc, result);
1515 
1516       return success();
1517     }
1518 
1519     if (op.getMode() == "BILINEAR") {
1520       Value y0 = iy;
1521       Value x0 = ix;
1522 
1523       auto oneVal = rewriter.create<arith::ConstantOp>(
1524           loc, rewriter.getI32IntegerAttr(1));
1525       Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
1526       Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);
1527 
1528       y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
1529                                       arith::CmpIPredicate::slt, rewriter);
1530       y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
1531                                       arith::CmpIPredicate::slt, rewriter);
1532 
1533       x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
1534                                       arith::CmpIPredicate::slt, rewriter);
1535       x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
1536                                       arith::CmpIPredicate::slt, rewriter);
1537 
1538       y0 =
1539           rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
1540       y1 =
1541           rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y1);
1542       x0 =
1543           rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x0);
1544       x1 =
1545           rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x1);
1546 
1547       Value y0x0 = rewriter.create<tensor::ExtractOp>(
1548           loc, input, ValueRange{batch, y0, x0, channel});
1549       Value y0x1 = rewriter.create<tensor::ExtractOp>(
1550           loc, input, ValueRange{batch, y0, x1, channel});
1551       Value y1x0 = rewriter.create<tensor::ExtractOp>(
1552           loc, input, ValueRange{batch, y1, x0, channel});
1553       Value y1x1 = rewriter.create<tensor::ExtractOp>(
1554           loc, input, ValueRange{batch, y1, x1, channel});
1555 
1556       if (floatingPointMode) {
1557         auto oneVal = rewriter.create<arith::ConstantOp>(
1558             loc, rewriter.getF32FloatAttr(1.f));
1559         Value rightPart = dx;
1560         Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);
1561 
1562         y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
1563         y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
1564         Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);
1565 
1566         y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
1567         y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
1568         Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);
1569 
1570         Value bottomPart = dy;
1571         Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
1572         topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
1573         bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
1574         Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);
1575 
1576         rewriter.create<linalg::YieldOp>(loc, result);
1577         return success();
1578       }
1579       y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
1580       y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
1581       y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
1582       y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
1583 
1584       if (resultElementTy.getIntOrFloatBitWidth() > 32) {
1585         dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
1586         dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
1587       }
1588 
1589       auto unitVal = rewriter.create<arith::ConstantOp>(
1590           loc, rewriter.getIntegerAttr(resultElementTy, 1LL << shift));
1591       Value rightPart = dx;
1592       Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx);
1593 
1594       y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
1595       y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
1596       Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);
1597 
1598       y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
1599       y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
1600       Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);
1601 
1602       Value bottomPart = dy;
1603       Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy);
1604       topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
1605       bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
1606       Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);
1607 
1608       rewriter.create<linalg::YieldOp>(loc, result);
1609       return success();
1610     }
1611     return failure();
1612   }
1613 };
1614 
1615 // At the codegen level any identity operations should be removed. Any cases
1616 // where identity is load-bearing (e.g. cross device computation) should be
1617 // handled before lowering to codegen.
1618 template <typename SrcOp>
1619 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1620 public:
1621   using OpRewritePattern<SrcOp>::OpRewritePattern;
1622 
matchAndRewrite(SrcOp op,PatternRewriter & rewriter) const1623   LogicalResult matchAndRewrite(SrcOp op,
1624                                 PatternRewriter &rewriter) const final {
1625     rewriter.replaceOp(op, op.getOperation()->getOperands());
1626     return success();
1627   }
1628 };
1629 
1630 template <typename SrcOp>
1631 class ReduceConverter : public OpRewritePattern<SrcOp> {
1632 public:
1633   using OpRewritePattern<SrcOp>::OpRewritePattern;
1634 
matchAndRewrite(SrcOp reduceOp,PatternRewriter & rewriter) const1635   LogicalResult matchAndRewrite(SrcOp reduceOp,
1636                                 PatternRewriter &rewriter) const final {
1637     return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1638   }
1639 };
1640 
// Lowers tosa.concat to a zero-filled init_tensor of the full result shape
// followed by one tensor.insert_slice per input operand, advancing a running
// offset along the concatenation axis after each insertion.
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
    auto resultType = op.getType().dyn_cast<RankedTensorType>();

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(axis));
    int rank = resultType.getRank();
    // Slices are unit-stride and start at offset 0 in every dimension; only
    // offsets[axis] is advanced as inputs are inserted below.
    SmallVector<Value, 3> offsets, sizes, strides;
    sizes.reserve(rank);
    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));

    // Seed `sizes` with the first input's extents; dynamic extents of the
    // first input also size the init_tensor.
    // NOTE(review): dynDims is keyed off input 0's dynamism — this assumes
    // the result's dynamic dims coincide with input 0's; confirm for the
    // case where only the concat axis is dynamic in the result.
    SmallVector<Value> dynDims;
    for (int i = 0; i < rank; ++i) {
      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          loc, adaptor.getOperands()[0], i));
      if (inputType.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    // The result extent along the concat axis is the sum of every input's
    // extent along that axis.
    Value resultDimSize = sizes[axis];
    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      resultDimSize =
          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
    }
    sizes[axis] = resultDimSize;

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, dynDims, resultType.getShape(), resultType.getElementType());

    // Zero-fill so every element is defined before the slices are inserted.
    Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(resultType.getElementType()));
    Value result =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{zeroVal}, ValueRange{init})
            .result();

    // Fold constant index Values into attributes so insert_slice receives
    // static offsets/sizes/strides where possible.
    auto toOpFoldResult = [](Value v) -> OpFoldResult {
      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
      if (!op)
        return v;
      return op.getValue();
    };
    // Insert each input at the current offset along the axis, then advance
    // the offset by that input's extent.
    for (auto arg : adaptor.getOperands()) {
      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result,
          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
      offsets[axis] =
          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
1708 
1709 class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
1710 public:
1711   using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
1712 
matchAndRewrite(tosa::ReverseOp op,PatternRewriter & rewriter) const1713   LogicalResult matchAndRewrite(tosa::ReverseOp op,
1714                                 PatternRewriter &rewriter) const final {
1715     auto loc = op.getLoc();
1716     Value input = op.getInput();
1717     auto inputTy = input.getType().template cast<ShapedType>();
1718     auto resultTy = op.getType().template cast<ShapedType>();
1719     auto axis = op.getAxis();
1720 
1721     SmallVector<Value> dynDims;
1722     for (int i = 0; i < inputTy.getRank(); i++) {
1723       if (inputTy.isDynamicDim(i)) {
1724         dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1725       }
1726     }
1727 
1728     Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1729 
1730     // First fill the output buffer with the init value.
1731     auto initTensor = rewriter
1732                           .create<linalg::InitTensorOp>(
1733                               loc, ArrayRef<Value>({dynDims}),
1734                               inputTy.getShape(), inputTy.getElementType())
1735                           .result();
1736     SmallVector<AffineMap, 2> affineMaps = {
1737         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1738 
1739     rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1740         op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
1741         getNParallelLoopsAttrs(resultTy.getRank()),
1742         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1743           llvm::SmallVector<Value> indices;
1744           for (unsigned int i = 0; i < inputTy.getRank(); i++) {
1745             auto index =
1746                 rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1747             if (i == axis) {
1748               auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1749               auto sizeMinusOne =
1750                   rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1751               index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1752                                                      index);
1753             }
1754 
1755             indices.push_back(index);
1756           }
1757 
1758           auto extract = nestedBuilder.create<tensor::ExtractOp>(
1759               nestedLoc, input, indices);
1760           nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
1761                                                 extract.getResult());
1762         });
1763     return success();
1764   }
1765 };
1766 
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcast to the appropriate
// multiple.
1771 struct TileConverter : public OpConversionPattern<tosa::TileOp> {
1772   using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
1773 
1774   LogicalResult
matchAndRewrite__anonadc6429f0411::TileConverter1775   matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1776                   ConversionPatternRewriter &rewriter) const override {
1777     auto loc = op.getLoc();
1778     auto input = op.getInput1();
1779     auto inputTy = input.getType().cast<ShapedType>();
1780     auto inputShape = inputTy.getShape();
1781     auto resultTy = op.getType().cast<ShapedType>();
1782     auto elementTy = inputTy.getElementType();
1783     int64_t rank = inputTy.getRank();
1784 
1785     SmallVector<int64_t> multiples;
1786     getValuesFromIntArrayAttribute(op.getMultiples(), multiples);
1787 
1788     // Broadcast the newly added dimensions to their appropriate multiple.
1789     SmallVector<int64_t, 2> genericShape;
1790     for (int i = 0; i < rank; i++) {
1791       genericShape.push_back(multiples[i]);
1792       genericShape.push_back(inputShape[i]);
1793     }
1794 
1795     SmallVector<Value> dynDims;
1796     for (int i = 0; i < inputTy.getRank(); i++) {
1797       if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1798         dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1799       }
1800     }
1801 
1802     auto initTensor = rewriter.create<linalg::InitTensorOp>(
1803         op.getLoc(), dynDims, genericShape, elementTy);
1804 
1805     // We needs to map the input shape to the non-broadcasted dimensions.
1806     SmallVector<AffineExpr, 4> dimExprs;
1807     dimExprs.reserve(rank);
1808     for (unsigned i = 0; i < rank; ++i)
1809       dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
1810 
1811     auto readAffineMap =
1812         AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
1813                        rewriter.getContext());
1814 
1815     SmallVector<AffineMap, 2> affineMaps = {
1816         readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
1817 
1818     auto genericOp = rewriter.create<linalg::GenericOp>(
1819         loc, RankedTensorType::get(genericShape, elementTy), input,
1820         ValueRange{initTensor}, affineMaps,
1821         getNParallelLoopsAttrs(genericShape.size()),
1822         [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1823           nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
1824         });
1825 
1826     rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
1827         op, resultTy, genericOp.getResult(0),
1828         rewriter.getI64ArrayAttr(resultTy.getShape()));
1829     return success();
1830   }
1831 };
1832 
1833 class PadConverter : public OpRewritePattern<tosa::PadOp> {
1834 public:
1835   using OpRewritePattern<tosa::PadOp>::OpRewritePattern;
1836 
matchAndRewrite(tosa::PadOp padOp,PatternRewriter & rewriter) const1837   LogicalResult matchAndRewrite(tosa::PadOp padOp,
1838                                 PatternRewriter &rewriter) const final {
1839     auto loc = padOp.getLoc();
1840     auto input = padOp.getInput1();
1841     auto padding = padOp.getPadding();
1842 
1843     ShapedType inputTy = input.getType().cast<ShapedType>();
1844     Type elementTy = inputTy.getElementType();
1845     int64_t rank = inputTy.getRank();
1846 
1847     // Setup the default constantAttr.
1848 
1849     Value padConstant;
1850 
1851     if (padOp.getPadConst()) {
1852       padConstant = rewriter.createOrFold<tensor::ExtractOp>(
1853           loc, padOp.getPadConst(), ValueRange({}));
1854     } else {
1855       Attribute constantAttr;
1856       if (elementTy.isa<FloatType>()) {
1857         constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
1858       } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
1859         constantAttr = rewriter.getIntegerAttr(elementTy, 0);
1860       } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
1861         int64_t value = padOp.getQuantizationInfo()->getInputZp();
1862         constantAttr = rewriter.getIntegerAttr(elementTy, value);
1863       }
1864       if (constantAttr)
1865         padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
1866     }
1867 
1868     if (!padConstant) {
1869       return rewriter.notifyMatchFailure(
1870           padOp, "tosa.pad was unable to determine the pad constant value.");
1871     }
1872 
1873     Value lowIndex =
1874         rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
1875     Value highIndex =
1876         rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
1877 
1878     SmallVector<OpFoldResult, 3> lowValues;
1879     SmallVector<OpFoldResult, 3> highValues;
1880 
1881     lowValues.reserve(rank);
1882     highValues.reserve(rank);
1883 
1884     for (int i = 0; i < rank; i++) {
1885       Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
1886       Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
1887           loc, padding, ValueRange({inputIndex, lowIndex}));
1888       Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
1889           loc, padding, ValueRange({inputIndex, highIndex}));
1890 
1891       lowVal = rewriter.createOrFold<arith::IndexCastOp>(
1892           loc, rewriter.getIndexType(), lowVal);
1893       highVal = rewriter.createOrFold<arith::IndexCastOp>(
1894           loc, rewriter.getIndexType(), highVal);
1895 
1896       lowValues.push_back(lowVal);
1897       highValues.push_back(highVal);
1898     }
1899 
1900     auto newPadOp = tensor::createPadScalarOp(
1901         padOp.getType(), input, padConstant, lowValues, highValues,
1902         /*nofold=*/false, loc, rewriter);
1903 
1904     rewriter.replaceOp(padOp, newPadOp.getResult());
1905     return success();
1906   }
1907 };
1908 
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = argmaxOp.getOutput().getType().cast<ShapedType>();
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Type of the auxiliary running-max buffer: the result's shape but the
    // input's element type (it stores values, not indices).
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!outElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Dynamic dims of the reduced result: every dynamic input dim except the
    // reduction axis survives into the output shape.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index (initialized to 0).
    auto initTensorIdx =
        rewriter
            .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
                                          outElementTy)
            .result();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{initTensorIdx})
            .result();

    // Second fill the output buffer for the running max with the reduce op's
    // neutral initial value so any real input replaces it.
    auto initTensorMax = rewriter
                             .create<linalg::InitTensorOp>(
                                 loc, dynDims, resultTy.getShape(), inElementTy)
                             .result();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{initTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<StringRef, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
    iteratorTypes[axis] = getReductionIteratorTypeName();

    // The input is read with the identity map; both outputs use the same map
    // with the reduced dimension dropped.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Current position along the reduction axis, cast to the result's
          // integer index type.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (inElementTy.isa<FloatType>()) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (inElementTy.isa<IntegerType>()) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            // Unsupported element type: the callback cannot return failure,
            // so record the error and bail out after the op is built.
            didEncounterError = true;
            return;
          }

          // Keep whichever value/index pair holds the larger value.
          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index tensor (result 0) is used; the max buffer is discarded.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
2040 
2041 class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
2042 public:
2043   using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2044   LogicalResult
matchAndRewrite(tosa::GatherOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const2045   matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2046                   ConversionPatternRewriter &rewriter) const final {
2047     auto input = adaptor.getOperands()[0];
2048     auto indices = adaptor.getOperands()[1];
2049 
2050     auto resultTy = op.getType().cast<ShapedType>();
2051 
2052     auto dynamicDimsOr = checkHasDynamicBatchDims(
2053         rewriter, op, {input, indices, op.getOutput()});
2054     if (!dynamicDimsOr.has_value())
2055       return failure();
2056     SmallVector<Value> dynamicDims = dynamicDimsOr.value();
2057 
2058     auto resultElementTy = resultTy.getElementType();
2059 
2060     auto loc = op.getLoc();
2061 
2062     auto initTensor =
2063         rewriter
2064             .create<linalg::InitTensorOp>(loc, dynamicDims, resultTy.getShape(),
2065                                           resultElementTy)
2066             .result();
2067 
2068     SmallVector<AffineMap, 2> affineMaps = {
2069         AffineMap::get(
2070             /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
2071             {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2072             rewriter.getContext()),
2073         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2074 
2075     auto genericOp = rewriter.create<linalg::GenericOp>(
2076         loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
2077         ValueRange{initTensor}, affineMaps,
2078         getNParallelLoopsAttrs(resultTy.getRank()),
2079         [&](OpBuilder &b, Location loc, ValueRange args) {
2080           auto indexValue = args[0];
2081           auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
2082           Value index1 = rewriter.create<arith::IndexCastOp>(
2083               loc, rewriter.getIndexType(), indexValue);
2084           auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
2085           Value extract = rewriter.create<tensor::ExtractOp>(
2086               loc, input, ValueRange{index0, index1, index2});
2087           rewriter.create<linalg::YieldOp>(loc, extract);
2088         });
2089     rewriter.replaceOp(op, genericOp.getResult(0));
2090     return success();
2091   }
2092 };
2093 
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
2097 class TableConverter : public OpRewritePattern<tosa::TableOp> {
2098 public:
2099   using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2100 
matchAndRewrite(tosa::TableOp op,PatternRewriter & rewriter) const2101   LogicalResult matchAndRewrite(tosa::TableOp op,
2102                                 PatternRewriter &rewriter) const final {
2103     auto loc = op.getLoc();
2104     Value input = op.getInput();
2105     Value table = op.getTable();
2106     auto inputTy = input.getType().cast<ShapedType>();
2107     auto tableTy = table.getType().cast<ShapedType>();
2108     auto resultTy = op.getType().cast<ShapedType>();
2109 
2110     auto inputElementTy = inputTy.getElementType();
2111     auto tableElementTy = tableTy.getElementType();
2112     auto resultElementTy = resultTy.getElementType();
2113 
2114     SmallVector<Value> dynDims;
2115     for (int i = 0; i < resultTy.getRank(); ++i) {
2116       if (inputTy.isDynamicDim(i)) {
2117         dynDims.push_back(
2118             rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
2119       }
2120     }
2121 
2122     auto initTensor =
2123         rewriter
2124             .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
2125                                           resultElementTy)
2126             .result();
2127 
2128     SmallVector<AffineMap, 2> affineMaps = {
2129         rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2130         rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2131 
2132     auto genericOp = rewriter.create<linalg::GenericOp>(
2133         loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
2134         getNParallelLoopsAttrs(resultTy.getRank()));
2135     rewriter.replaceOp(op, genericOp.getResult(0));
2136 
2137     {
2138       OpBuilder::InsertionGuard regionGuard(rewriter);
2139       Block *block = rewriter.createBlock(
2140           &genericOp.region(), genericOp.region().end(),
2141           TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2142 
2143       auto inputValue = block->getArgument(0);
2144       rewriter.setInsertionPointToStart(block);
2145       if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2146           resultElementTy.isInteger(8)) {
2147         Value index = rewriter.create<arith::IndexCastOp>(
2148             loc, rewriter.getIndexType(), inputValue);
2149         Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
2150         index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
2151                                                index, offset);
2152         Value extract =
2153             rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2154         rewriter.create<linalg::YieldOp>(loc, extract);
2155         return success();
2156       }
2157 
2158       if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2159           resultElementTy.isInteger(32)) {
2160         Value extend = rewriter.create<arith::ExtSIOp>(
2161             loc, rewriter.getI32Type(), inputValue);
2162 
2163         auto offset = rewriter.create<arith::ConstantOp>(
2164             loc, rewriter.getI32IntegerAttr(32768));
2165         auto seven = rewriter.create<arith::ConstantOp>(
2166             loc, rewriter.getI32IntegerAttr(7));
2167         auto one = rewriter.create<arith::ConstantOp>(
2168             loc, rewriter.getI32IntegerAttr(1));
2169         auto b1111111 = rewriter.create<arith::ConstantOp>(
2170             loc, rewriter.getI32IntegerAttr(127));
2171 
2172         // Compute the index and fractional part from the input value:
2173         // value = value + 32768
2174         // index = value >> 7;
2175         // fraction = 0x01111111 & value
2176         auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
2177         Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
2178         Value fraction =
2179             rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);
2180 
2181         // Extract the base and next values from the table.
2182         // base = (int32_t) table[index];
2183         // next = (int32_t) table[index + 1];
2184         Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);
2185 
2186         index = rewriter.create<arith::IndexCastOp>(
2187             loc, rewriter.getIndexType(), index);
2188         indexPlusOne = rewriter.create<arith::IndexCastOp>(
2189             loc, rewriter.getIndexType(), indexPlusOne);
2190 
2191         Value base =
2192             rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
2193         Value next = rewriter.create<tensor::ExtractOp>(
2194             loc, table, ValueRange{indexPlusOne});
2195 
2196         base =
2197             rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
2198         next =
2199             rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);
2200 
2201         // Use the fractional part to interpolate between the input values:
2202         // result = (base << 7) + (next - base) * fraction
2203         Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
2204         Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
2205         Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
2206         Value result =
2207             rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
2208 
2209         rewriter.create<linalg::YieldOp>(loc, result);
2210 
2211         return success();
2212       }
2213     }
2214 
2215     return rewriter.notifyMatchFailure(
2216         op, "unable to create body for tosa.table op");
2217   }
2218 };
2219 
2220 } // namespace
2221 
populateTosaToLinalgConversionPatterns(RewritePatternSet * patterns)2222 void mlir::tosa::populateTosaToLinalgConversionPatterns(
2223     RewritePatternSet *patterns) {
2224   patterns->add<
2225       // clang-format off
2226       PointwiseConverter<tosa::AddOp>,
2227       PointwiseConverter<tosa::SubOp>,
2228       PointwiseConverter<tosa::MulOp>,
2229       PointwiseConverter<tosa::DivOp>,
2230       PointwiseConverter<tosa::NegateOp>,
2231       PointwiseConverter<tosa::PowOp>,
2232       PointwiseConverter<tosa::ReciprocalOp>,
2233       PointwiseConverter<tosa::RsqrtOp>,
2234       PointwiseConverter<tosa::LogOp>,
2235       PointwiseConverter<tosa::ExpOp>,
2236       PointwiseConverter<tosa::AbsOp>,
2237       PointwiseConverter<tosa::TanhOp>,
2238       PointwiseConverter<tosa::BitwiseAndOp>,
2239       PointwiseConverter<tosa::BitwiseOrOp>,
2240       PointwiseConverter<tosa::BitwiseNotOp>,
2241       PointwiseConverter<tosa::BitwiseXorOp>,
2242       PointwiseConverter<tosa::LogicalAndOp>,
2243       PointwiseConverter<tosa::LogicalNotOp>,
2244       PointwiseConverter<tosa::LogicalOrOp>,
2245       PointwiseConverter<tosa::LogicalXorOp>,
2246       PointwiseConverter<tosa::CastOp>,
2247       PointwiseConverter<tosa::LogicalLeftShiftOp>,
2248       PointwiseConverter<tosa::LogicalRightShiftOp>,
2249       PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2250       PointwiseConverter<tosa::ClzOp>,
2251       PointwiseConverter<tosa::SelectOp>,
2252       PointwiseConverter<tosa::GreaterOp>,
2253       PointwiseConverter<tosa::GreaterEqualOp>,
2254       PointwiseConverter<tosa::EqualOp>,
2255       PointwiseConverter<tosa::MaximumOp>,
2256       PointwiseConverter<tosa::MinimumOp>,
2257       PointwiseConverter<tosa::CeilOp>,
2258       PointwiseConverter<tosa::FloorOp>,
2259       PointwiseConverter<tosa::ClampOp>,
2260       PointwiseConverter<tosa::ReluNOp>,
2261       PointwiseConverter<tosa::SigmoidOp>,
2262       IdentityNConverter<tosa::IdentityOp>,
2263       ReduceConverter<tosa::ReduceAllOp>,
2264       ReduceConverter<tosa::ReduceAnyOp>,
2265       ReduceConverter<tosa::ReduceMinOp>,
2266       ReduceConverter<tosa::ReduceMaxOp>,
2267       ReduceConverter<tosa::ReduceSumOp>,
2268       ReduceConverter<tosa::ReduceProdOp>,
2269       ArgMaxConverter,
2270       ConcatConverter,
2271       GatherConverter,
2272       PadConverter,
2273       ReshapeConverterCollapse,
2274       ReshapeConverterExpand,
2275       ReshapeConverterCollapseExpand,
2276       RescaleConverter,
2277       ResizeConverter,
2278       ReverseConverter,
2279       TableConverter,
2280       TileConverter,
2281       TransposeConverter>(patterns->getContext());
2282   // clang-format on
2283 }
2284