1 //===- TosaToLinalg.cpp - Lowering Tosa to Linalg Dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the Tosa to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/Linalg/IR/Linalg.h"
16 #include "mlir/Dialect/Math/IR/Math.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/Tensor/IR/Tensor.h"
19 #include "mlir/Dialect/Tensor/Utils/Utils.h"
20 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
21 #include "mlir/Dialect/Tosa/Utils/CoversionUtils.h"
22 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
23 #include "mlir/IR/Matchers.h"
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Transforms/DialectConversion.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27
28 #include <numeric>
29
30 using namespace mlir;
31 using namespace mlir::tosa;
32
33 template <typename T>
34 static arith::ConstantOp
createConstFromIntAttribute(Operation * op,const std::string & attrName,Type requiredAttrType,OpBuilder & rewriter)35 createConstFromIntAttribute(Operation *op, const std::string &attrName,
36 Type requiredAttrType, OpBuilder &rewriter) {
37 auto castedN = static_cast<T>(
38 op->getAttr(attrName).cast<IntegerAttr>().getValue().getSExtValue());
39 return rewriter.create<arith::ConstantOp>(
40 op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
41 }
42
43 static Value
createLinalgBodyCalculationForElementwiseOp(Operation * op,ValueRange args,ArrayRef<Type> resultTypes,PatternRewriter & rewriter)44 createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
45 ArrayRef<Type> resultTypes,
46 PatternRewriter &rewriter) {
47 Location loc = op->getLoc();
48 auto elementTy =
49 op->getOperand(0).getType().cast<ShapedType>().getElementType();
50
51 // tosa::AbsOp
52 if (isa<tosa::AbsOp>(op) && elementTy.isa<FloatType>())
53 return rewriter.create<math::AbsOp>(loc, resultTypes, args);
54
55 if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
56 auto zero = rewriter.create<arith::ConstantOp>(
57 loc, rewriter.getZeroAttr(elementTy));
58 auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
59 args[0], zero);
60 auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
61 return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
62 }
63
64 // tosa::AddOp
65 if (isa<tosa::AddOp>(op) && elementTy.isa<FloatType>())
66 return rewriter.create<arith::AddFOp>(loc, resultTypes, args);
67
68 if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
69 return rewriter.create<arith::AddIOp>(loc, resultTypes, args);
70
71 // tosa::SubOp
72 if (isa<tosa::SubOp>(op) && elementTy.isa<FloatType>())
73 return rewriter.create<arith::SubFOp>(loc, resultTypes, args);
74
75 if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
76 return rewriter.create<arith::SubIOp>(loc, resultTypes, args);
77
78 // tosa::MulOp
79 if (isa<tosa::MulOp>(op) && elementTy.isa<FloatType>()) {
80 if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
81 (void)rewriter.notifyMatchFailure(op,
82 "Cannot have shift value for float");
83 return nullptr;
84 }
85 return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
86 }
87
88 // tosa::DivOp
89 if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
90 return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
91
92 // tosa::ReciprocalOp
93 if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
94 auto one =
95 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
96 return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
97 }
98
99 if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
100 Value a = args[0];
101 Value b = args[1];
102 auto shift =
103 op->getAttr("shift").cast<IntegerAttr>().getValue().getSExtValue();
104 if (shift > 0) {
105 auto shiftConst =
106 rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
107 if (!a.getType().isInteger(32))
108 a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
109
110 if (!b.getType().isInteger(32))
111 b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
112
113 auto result = rewriter.create<tosa::ApplyScaleOp>(
114 loc, rewriter.getI32Type(), a, b, shiftConst,
115 rewriter.getBoolAttr(false));
116
117 if (elementTy.isInteger(32))
118 return result;
119
120 return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
121 }
122
123 int aWidth = a.getType().getIntOrFloatBitWidth();
124 int bWidth = b.getType().getIntOrFloatBitWidth();
125 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
126
127 if (aWidth < cWidth)
128 a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
129 if (bWidth < cWidth)
130 b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
131
132 return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
133 }
134
135 // tosa::NegateOp
136 if (isa<tosa::NegateOp>(op) && elementTy.isa<FloatType>())
137 return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
138
139 if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
140 !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
141 auto constant =
142 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
143 return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
144 }
145
146 if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
147 cast<tosa::NegateOp>(op).getQuantizationInfo()) {
148 auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
149 int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
150 int64_t inZp = quantizationInfo.value().getInputZp();
151 int64_t outZp = quantizationInfo.value().getOutputZp();
152
153 // Compute the maximum value that can occur in the intermediate buffer.
154 int64_t zpAdd = inZp + outZp;
155 int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
156 std::abs(zpAdd) + 1;
157
158 // Convert that maximum value into the maximum bitwidth needed to represent
159 // it. We assume 48-bit numbers may be supported further in the pipeline.
160 int intermediateBitWidth = 64;
161 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
162 intermediateBitWidth = 16;
163 } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
164 intermediateBitWidth = 32;
165 } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
166 intermediateBitWidth = 48;
167 }
168
169 Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
170 Value zpAddValue = rewriter.create<arith::ConstantOp>(
171 loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
172
173 // The negation can be applied by doing:
174 // outputValue = inZp + outZp - inputValue
175 auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
176 auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
177
178 // Clamp to the negation range.
179 auto min = rewriter.create<arith::ConstantIntOp>(
180 loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
181 intermediateType);
182 auto max = rewriter.create<arith::ConstantIntOp>(
183 loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
184 intermediateType);
185 auto clamp = clampHelper<arith::CmpIOp>(
186 loc, sub, min, max, arith::CmpIPredicate::slt, rewriter);
187
188 // Truncate to the final value.
189 return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
190 }
191
192 // tosa::BitwiseAndOp
193 if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
194 return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
195
196 // tosa::BitwiseOrOp
197 if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
198 return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
199
200 // tosa::BitwiseNotOp
201 if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
202 auto allOnesAttr = rewriter.getIntegerAttr(
203 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
204 auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
205 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
206 }
207
208 // tosa::BitwiseXOrOp
209 if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
210 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
211
212 // tosa::LogicalLeftShiftOp
213 if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
214 return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);
215
216 // tosa::LogicalRightShiftOp
217 if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
218 return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
219
220 // tosa::ArithmeticRightShiftOp
221 if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
222 auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
223 auto round = op->getAttr("round").cast<BoolAttr>().getValue();
224 if (!round) {
225 return result;
226 }
227
228 Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
229 auto one =
230 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
231 auto zero =
232 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
233 auto i1one =
234 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
235
236 // Checking that input2 != 0
237 auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
238 loc, arith::CmpIPredicate::sgt, args[1], zero);
239
240 // Checking for the last bit of input1 to be 1
241 auto subtract =
242 rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
243 auto shifted =
244 rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
245 ->getResults();
246 auto truncated =
247 rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, mlir::None);
248 auto isInputOdd =
249 rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
250
251 auto shouldRound = rewriter.create<arith::AndIOp>(
252 loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
253 auto extended =
254 rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
255 return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
256 }
257
258 // tosa::ClzOp
259 if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
260 return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
261 }
262
263 // tosa::LogicalAnd
264 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
265 return rewriter.create<arith::AndIOp>(loc, resultTypes, args);
266
267 // tosa::LogicalNot
268 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
269 auto one = rewriter.create<arith::ConstantOp>(
270 loc, rewriter.getIntegerAttr(elementTy, 1));
271 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
272 }
273
274 // tosa::LogicalOr
275 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
276 return rewriter.create<arith::OrIOp>(loc, resultTypes, args);
277
278 // tosa::LogicalXor
279 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
280 return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
281
282 // tosa::PowOp
283 if (isa<tosa::PowOp>(op) && elementTy.isa<FloatType>())
284 return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);
285
286 // tosa::RsqrtOp
287 if (isa<tosa::RsqrtOp>(op) && elementTy.isa<FloatType>())
288 return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);
289
290 // tosa::LogOp
291 if (isa<tosa::LogOp>(op) && elementTy.isa<FloatType>())
292 return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);
293
294 // tosa::ExpOp
295 if (isa<tosa::ExpOp>(op) && elementTy.isa<FloatType>())
296 return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);
297
298 // tosa::TanhOp
299 if (isa<tosa::TanhOp>(op) && elementTy.isa<FloatType>())
300 return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);
301
302 // tosa::GreaterOp
303 if (isa<tosa::GreaterOp>(op) && elementTy.isa<FloatType>())
304 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
305 args[0], args[1]);
306
307 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
308 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
309 args[0], args[1]);
310
311 // tosa::GreaterEqualOp
312 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<FloatType>())
313 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
314 args[0], args[1]);
315
316 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
317 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
318 args[0], args[1]);
319
320 // tosa::EqualOp
321 if (isa<tosa::EqualOp>(op) && elementTy.isa<FloatType>())
322 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
323 args[0], args[1]);
324
325 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
326 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
327 args[0], args[1]);
328
329 // tosa::SelectOp
330 if (isa<tosa::SelectOp>(op)) {
331 elementTy = op->getOperand(1).getType().cast<ShapedType>().getElementType();
332 if (elementTy.isa<FloatType>() || elementTy.isa<IntegerType>())
333 return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
334 }
335
336 // tosa::MaximumOp
337 if (isa<tosa::MaximumOp>(op) && elementTy.isa<FloatType>()) {
338 auto predicate = rewriter.create<arith::CmpFOp>(
339 loc, arith::CmpFPredicate::OGT, args[0], args[1]);
340 return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
341 }
342
343 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
344 auto predicate = rewriter.create<arith::CmpIOp>(
345 loc, arith::CmpIPredicate::sgt, args[0], args[1]);
346 return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
347 }
348
349 // tosa::MinimumOp
350 if (isa<tosa::MinimumOp>(op) && elementTy.isa<FloatType>()) {
351 auto predicate = rewriter.create<arith::CmpFOp>(
352 loc, arith::CmpFPredicate::OLT, args[0], args[1]);
353 return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
354 }
355
356 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
357 auto predicate = rewriter.create<arith::CmpIOp>(
358 loc, arith::CmpIPredicate::slt, args[0], args[1]);
359 return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
360 }
361
362 // tosa::CeilOp
363 if (isa<tosa::CeilOp>(op) && elementTy.isa<FloatType>())
364 return rewriter.create<math::CeilOp>(loc, resultTypes, args);
365
366 // tosa::FloorOp
367 if (isa<tosa::FloorOp>(op) && elementTy.isa<FloatType>())
368 return rewriter.create<math::FloorOp>(loc, resultTypes, args);
369
370 // tosa::ClampOp
371 if (isa<tosa::ClampOp>(op) && elementTy.isa<FloatType>()) {
372 bool losesInfo = false;
373 APFloat min_apf = op->getAttr("min_fp").cast<FloatAttr>().getValue();
374 APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
375 min_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
376 APFloat::rmNearestTiesToEven, &losesInfo);
377 max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
378 APFloat::rmNearestTiesToEven, &losesInfo);
379 auto min = rewriter.create<arith::ConstantOp>(
380 loc, elementTy, rewriter.getFloatAttr(elementTy, min_apf));
381 auto max = rewriter.create<arith::ConstantOp>(
382 loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
383 return clampHelper<arith::CmpFOp>(loc, args[0], min, max,
384 arith::CmpFPredicate::OLT, rewriter);
385 }
386
387 if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
388 auto intTy = elementTy.cast<IntegerType>();
389 int32_t min = static_cast<int32_t>(
390 op->getAttr("min_int").cast<IntegerAttr>().getValue().getSExtValue());
391 int32_t max = static_cast<int32_t>(
392 op->getAttr("max_int").cast<IntegerAttr>().getValue().getSExtValue());
393
394 if (intTy.isUnsignedInteger()) {
395 min = std::max<int32_t>(min, 0);
396 max = std::min<int32_t>(
397 max,
398 APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
399 } else {
400 min = std::max<int32_t>(
401 min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
402 .getSExtValue());
403 max = std::min<int32_t>(
404 max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
405 .getSExtValue());
406 }
407
408 auto minVal = rewriter.create<arith::ConstantIntOp>(
409 loc, min, intTy.getIntOrFloatBitWidth());
410 auto maxVal = rewriter.create<arith::ConstantIntOp>(
411 loc, max, intTy.getIntOrFloatBitWidth());
412 return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
413 arith::CmpIPredicate::slt, rewriter);
414 }
415
416 // tosa::ReluNOp
417 if (isa<tosa::ReluNOp>(op) && elementTy.isa<FloatType>()) {
418 auto zero =
419 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
420 bool losesInfo = false;
421 APFloat max_apf = op->getAttr("max_fp").cast<FloatAttr>().getValue();
422 max_apf.convert(elementTy.cast<FloatType>().getFloatSemantics(),
423 APFloat::rmNearestTiesToEven, &losesInfo);
424 auto n = rewriter.create<arith::ConstantOp>(
425 loc, elementTy, rewriter.getFloatAttr(elementTy, max_apf));
426 return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
427 arith::CmpFPredicate::OLT, rewriter);
428 }
429
430 if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
431 auto zero =
432 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
433 auto n = createConstFromIntAttribute<int32_t>(op, "max_int", elementTy,
434 rewriter);
435 return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
436 arith::CmpIPredicate::slt, rewriter);
437 }
438
439 // tosa::SigmoidOp
440 if (isa<tosa::SigmoidOp>(op) && elementTy.isa<FloatType>()) {
441 auto one =
442 rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
443 auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
444 auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
445 auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
446 return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
447 }
448
449 // tosa::CastOp
450 if (isa<tosa::CastOp>(op)) {
451 Type srcTy = elementTy;
452 Type dstTy = resultTypes.front();
453 bool bitExtend =
454 srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();
455
456 if (srcTy == dstTy)
457 return args.front();
458
459 if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && bitExtend)
460 return rewriter.create<arith::ExtFOp>(loc, resultTypes, args, mlir::None);
461
462 if (srcTy.isa<FloatType>() && dstTy.isa<FloatType>() && !bitExtend)
463 return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
464 mlir::None);
465
466 // 1-bit integers need to be treated as signless.
467 if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
468 return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
469 mlir::None);
470
471 if (srcTy.isInteger(1) && dstTy.isa<IntegerType>() && bitExtend)
472 return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
473 mlir::None);
474
475 // Unsigned integers need an unrealized cast so that they can be passed
476 // to UIToFP.
477 if (srcTy.isUnsignedInteger() && dstTy.isa<FloatType>()) {
478 auto unrealizedCast =
479 rewriter
480 .create<UnrealizedConversionCastOp>(
481 loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
482 args[0])
483 .getResult(0);
484 return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
485 unrealizedCast);
486 }
487
488 // All other si-to-fp conversions should be handled by SIToFP.
489 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
490 return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
491 mlir::None);
492
493 // Casting to boolean, floats need to only be checked as not-equal to zero.
494 if (srcTy.isa<FloatType>() && dstTy.isInteger(1)) {
495 Value zero = rewriter.create<arith::ConstantOp>(
496 loc, rewriter.getFloatAttr(srcTy, 0.0));
497 return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
498 args.front(), zero);
499 }
500
501 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
502 auto zero = rewriter.create<arith::ConstantOp>(
503 loc, rewriter.getF32FloatAttr(0.0f));
504 auto half = rewriter.create<arith::ConstantOp>(
505 loc, rewriter.getF32FloatAttr(0.5f));
506
507 auto intMin = rewriter.create<arith::ConstantOp>(
508 loc, rewriter.getF32FloatAttr(
509 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
510 .getSExtValue()));
511
512 auto intMax = rewriter.create<arith::ConstantOp>(
513 loc, rewriter.getF32FloatAttr(
514 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
515 .getSExtValue()));
516
517 auto added = rewriter.create<arith::AddFOp>(loc, args[0], half);
518 auto subbed = rewriter.create<arith::SubFOp>(loc, args[0], half);
519 auto negative = rewriter.create<arith::CmpFOp>(
520 loc, arith::CmpFPredicate::OLT, args[0], zero);
521 auto rounded =
522 rewriter.create<arith::SelectOp>(loc, negative, subbed, added);
523
524 auto clamped = clampHelper<arith::CmpFOp>(
525 loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);
526
527 return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
528 }
529
530 // Casting to boolean, integers need to only be checked as not-equal to
531 // zero.
532 if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
533 Value zero = rewriter.create<arith::ConstantIntOp>(
534 loc, 0, srcTy.getIntOrFloatBitWidth());
535 return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
536 args.front(), zero);
537 }
538
539 if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && bitExtend)
540 return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
541 mlir::None);
542
543 if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
544 auto intMin = rewriter.create<arith::ConstantIntOp>(
545 loc,
546 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
547 .getSExtValue(),
548 srcTy.getIntOrFloatBitWidth());
549
550 auto intMax = rewriter.create<arith::ConstantIntOp>(
551 loc,
552 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
553 .getSExtValue(),
554 srcTy.getIntOrFloatBitWidth());
555
556 auto clamped = clampHelper<arith::CmpIOp>(
557 loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
558 return rewriter.create<arith::TruncIOp>(loc, dstTy, clamped);
559 }
560 }
561
562 (void)rewriter.notifyMatchFailure(
563 op, "unhandled op for linalg body calculation for elementwise op");
564 return nullptr;
565 }
566
567 static LogicalResult
elementwiseMatchAndRewriteHelper(Operation * operation,PatternRewriter & rewriter)568 elementwiseMatchAndRewriteHelper(Operation *operation,
569 PatternRewriter &rewriter) {
570 auto loc = operation->getLoc();
571
572 assert(operation->getNumResults() == 1 &&
573 "All TOSA elementwise ops should only return a single result.");
574
575 auto results = operation->getResults();
576 auto resultTy = operation->getResult(0).getType().dyn_cast<ShapedType>();
577
578 if (!resultTy)
579 return rewriter.notifyMatchFailure(operation,
580 "All results must be a shaped type");
581
582 unsigned rank = resultTy.getRank();
583
584 // Construct the indexing maps needed for linalg.generic ops.
585 SmallVector<Type> bodyArgTypes;
586
587 for (Value in : operation->getOperands())
588 bodyArgTypes.emplace_back(getElementTypeOrSelf(in.getType()));
589
590 SmallVector<Type> opResultTypes;
591 SmallVector<Value> initTensors;
592
593 SmallVector<Value> dynDims;
594 dynDims.resize(results.front().getType().cast<ShapedType>().getRank());
595
596 for (auto arg : operation->getOperands()) {
597 auto operandTy = arg.getType().cast<ShapedType>();
598 for (int i = 0; i < operandTy.getRank(); i++) {
599 if (operandTy.isDynamicDim(i) && !dynDims[i])
600 dynDims[i] = rewriter.create<tensor::DimOp>(loc, arg, i);
601 }
602 }
603
604 SmallVector<Value> filteredDims = condenseValues(dynDims);
605
606 for (auto result : results) {
607 auto resultTy = result.getType().template cast<ShapedType>();
608 initTensors.push_back(rewriter.create<linalg::InitTensorOp>(
609 loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
610 opResultTypes.push_back(result.getType());
611 }
612
613 auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
614 initTensors, [](Value v) { return getElementTypeOrSelf(v); }));
615
616 SmallVector<Value, 2> operands;
617 SmallVector<AffineMap, 2> indexingMaps;
618 indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size());
619
620 // Input indexing maps may be broadcasted.
621 for (Value operand : operation->getOperands()) {
622 ShapedType type = operand.getType().cast<ShapedType>();
623
624 if (type.getShape() == resultTy.getShape()) {
625 operands.push_back(operand);
626 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
627 continue;
628 }
629
630 SmallVector<int64_t, 5> newShape;
631 SmallVector<AffineExpr, 4> affineExprs;
632 newShape.reserve(type.getRank());
633 for (const auto &it : llvm::enumerate(type.getShape())) {
634 if (it.value() == resultTy.getDimSize(it.index())) {
635 newShape.push_back(it.value());
636 affineExprs.push_back(
637 mlir::getAffineDimExpr(it.index(), rewriter.getContext()));
638 }
639 }
640
641 if (newShape.size() != rank) {
642 operand = rewriter.create<tosa::ReshapeOp>(
643 loc, RankedTensorType::get(newShape, type.getElementType()), operand,
644 rewriter.getI64ArrayAttr(newShape));
645 }
646
647 operands.push_back(operand);
648 indexingMaps.push_back(AffineMap::get(
649 /*dimCount=*/type.getRank(), /*symbolCount=*/0, affineExprs,
650 rewriter.getContext()));
651 }
652
653 indexingMaps.append(operation->getNumResults(),
654 rewriter.getMultiDimIdentityMap(rank));
655
656 bool didEncounterError = false;
657 auto linalgOp = rewriter.create<linalg::GenericOp>(
658 loc, opResultTypes, operands, initTensors, indexingMaps,
659 getNParallelLoopsAttrs(rank),
660 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
661 Value opResult = createLinalgBodyCalculationForElementwiseOp(
662 operation, blockArgs.take_front(operation->getNumOperands()),
663 bodyResultTypes, rewriter);
664 if (!opResult) {
665 didEncounterError = true;
666 return;
667 }
668 nestedBuilder.create<linalg::YieldOp>(loc, opResult);
669 });
670
671 if (didEncounterError)
672 return failure();
673
674 rewriter.replaceOp(operation, linalgOp->getResults());
675 return success();
676 }
677
678 // Returns the constant initial value for a given reduction operation. The
679 // attribute type varies depending on the element type required.
createInitialValueForReduceOp(Operation * op,Type elementTy,PatternRewriter & rewriter)680 static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy,
681 PatternRewriter &rewriter) {
682 if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>())
683 return rewriter.getFloatAttr(elementTy, 0.0);
684
685 if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>())
686 return rewriter.getIntegerAttr(elementTy, 0);
687
688 if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>())
689 return rewriter.getFloatAttr(elementTy, 1.0);
690
691 if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>())
692 return rewriter.getIntegerAttr(elementTy, 1);
693
694 if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>())
695 return rewriter.getFloatAttr(
696 elementTy, APFloat::getLargest(
697 elementTy.cast<FloatType>().getFloatSemantics(), false));
698
699 if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>())
700 return rewriter.getIntegerAttr(
701 elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));
702
703 if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>())
704 return rewriter.getFloatAttr(
705 elementTy, APFloat::getLargest(
706 elementTy.cast<FloatType>().getFloatSemantics(), true));
707
708 if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>())
709 return rewriter.getIntegerAttr(
710 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
711
712 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
713 return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));
714
715 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
716 return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));
717
718 if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<FloatType>())
719 return rewriter.getFloatAttr(
720 elementTy, APFloat::getLargest(
721 elementTy.cast<FloatType>().getFloatSemantics(), true));
722
723 if (isa<tosa::ArgMaxOp>(op) && elementTy.isa<IntegerType>())
724 return rewriter.getIntegerAttr(
725 elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));
726
727 return {};
728 }
729
730 // Creates the body calculation for a reduction. The operations vary depending
731 // on the input type.
createLinalgBodyCalculationForReduceOp(Operation * op,ValueRange args,Type elementTy,PatternRewriter & rewriter)732 static Value createLinalgBodyCalculationForReduceOp(Operation *op,
733 ValueRange args,
734 Type elementTy,
735 PatternRewriter &rewriter) {
736 Location loc = op->getLoc();
737 if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<FloatType>()) {
738 return rewriter.create<arith::AddFOp>(loc, args);
739 }
740
741 if (isa<tosa::ReduceSumOp>(op) && elementTy.isa<IntegerType>()) {
742 return rewriter.create<arith::AddIOp>(loc, args);
743 }
744
745 if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<FloatType>()) {
746 return rewriter.create<arith::MulFOp>(loc, args);
747 }
748
749 if (isa<tosa::ReduceProdOp>(op) && elementTy.isa<IntegerType>()) {
750 return rewriter.create<arith::MulIOp>(loc, args);
751 }
752
753 if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<FloatType>()) {
754 auto predicate = rewriter.create<arith::CmpFOp>(
755 loc, arith::CmpFPredicate::OLT, args[0], args[1]);
756 return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
757 }
758
759 if (isa<tosa::ReduceMinOp>(op) && elementTy.isa<IntegerType>()) {
760 auto predicate = rewriter.create<arith::CmpIOp>(
761 loc, arith::CmpIPredicate::slt, args[0], args[1]);
762 return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
763 }
764
765 if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<FloatType>()) {
766 auto predicate = rewriter.create<arith::CmpFOp>(
767 loc, arith::CmpFPredicate::OGT, args[0], args[1]);
768 return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
769 }
770
771 if (isa<tosa::ReduceMaxOp>(op) && elementTy.isa<IntegerType>()) {
772 auto predicate = rewriter.create<arith::CmpIOp>(
773 loc, arith::CmpIPredicate::sgt, args[0], args[1]);
774 return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
775 }
776
777 if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
778 return rewriter.create<arith::AndIOp>(loc, args);
779
780 if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
781 return rewriter.create<arith::OrIOp>(loc, args);
782
783 return {};
784 }
785
786 // Performs the match and rewrite for reduction operations. This includes
787 // declaring a correctly sized initial value, and the linalg.generic operation
788 // that reduces across the specified axis.
reduceMatchAndRewriteHelper(Operation * op,uint64_t axis,PatternRewriter & rewriter)789 static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
790 PatternRewriter &rewriter) {
791 auto loc = op->getLoc();
792 auto inputTy = op->getOperand(0).getType().template cast<ShapedType>();
793 auto resultTy = op->getResult(0).getType().template cast<ShapedType>();
794 auto elementTy = resultTy.getElementType();
795 Value input = op->getOperand(0);
796
797 llvm::SmallVector<int64_t> reduceShape;
798 SmallVector<Value> dynDims;
799 for (unsigned i = 0; i < inputTy.getRank(); i++) {
800 if (axis != i) {
801 reduceShape.push_back(inputTy.getDimSize(i));
802 if (inputTy.isDynamicDim(i))
803 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
804 }
805 }
806
807 Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());
808
809 // First fill the output buffer with the init value.
810 auto initTensor = rewriter
811 .create<linalg::InitTensorOp>(loc, dynDims, reduceShape,
812 resultTy.getElementType())
813 .result();
814
815 auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
816 if (!fillValueAttr)
817 return rewriter.notifyMatchFailure(
818 op, "No initial value found for reduction operation");
819
820 auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
821 auto filledTensor = rewriter
822 .create<linalg::FillOp>(loc, ValueRange{fillValue},
823 ValueRange{initTensor})
824 .result();
825
826 SmallVector<AffineExpr, 2> srcExprs;
827 SmallVector<AffineExpr, 2> dstExprs;
828 SmallVector<StringRef, 4> iteratorTypes;
829 for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
830 srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
831
832 iteratorTypes.push_back(axis == i ? getReductionIteratorTypeName()
833 : getParallelIteratorTypeName());
834 if (axis != i)
835 dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
836 }
837
838 bool didEncounterError = false;
839 auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
840 auto linalgOp = rewriter.create<linalg::GenericOp>(
841 loc, reduceTy, input, filledTensor, maps, iteratorTypes,
842 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
843 auto result = createLinalgBodyCalculationForReduceOp(
844 op, blockArgs, elementTy, rewriter);
845 if (result)
846 didEncounterError = true;
847
848 nestedBuilder.create<linalg::YieldOp>(loc, result);
849 });
850
851 if (!didEncounterError)
852 return failure();
853
854 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(op, resultTy,
855 linalgOp.getResults());
856 return success();
857 }
858
findIntermediateShape(ArrayRef<int64_t> lhsShape,ArrayRef<int64_t> rhsShape,SmallVector<int64_t> & intermediateShape,bool isDynamic)859 static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
860 ArrayRef<int64_t> rhsShape,
861 SmallVector<int64_t> &intermediateShape,
862 bool isDynamic) {
863 if (isDynamic) {
864 // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
865 intermediateShape = {-1};
866 return true;
867 }
868
869 if (lhsShape.empty() || rhsShape.empty()) {
870 intermediateShape = {};
871 return true;
872 }
873
874 unsigned currLhsDim = 0, currRhsDim = 0;
875 while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
876 int64_t rhsSize = rhsShape[currRhsDim];
877 int64_t lhsSize = lhsShape[currLhsDim];
878 while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
879 currRhsDim < rhsShape.size()) {
880 if (lhsSize < rhsSize) {
881 currLhsDim++;
882 lhsSize *= lhsShape[currLhsDim];
883 } else {
884 currRhsDim++;
885 rhsSize *= rhsShape[currRhsDim];
886 }
887 }
888 if (lhsSize == rhsSize) {
889 intermediateShape.push_back(lhsSize);
890 }
891 currRhsDim++;
892 currLhsDim++;
893 }
894
895 // If the iterators didn't reach the end and their leftover dimensions are not
896 // equal to 1 an intermediate shape was not found.
897 while (currLhsDim < lhsShape.size()) {
898 if (lhsShape[currLhsDim++] != 1) {
899 return false;
900 }
901 }
902
903 while (currRhsDim < rhsShape.size()) {
904 if (rhsShape[currRhsDim++] != 1) {
905 return false;
906 }
907 }
908
909 return true;
910 }
911
createReassociationMapsForCollapse(PatternRewriter & rewriter,ArrayRef<int64_t> srcShape,ArrayRef<int64_t> dstShape,SmallVector<ReassociationExprs,4> & reassociationMap,bool isDynamic)912 static bool createReassociationMapsForCollapse(
913 PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
914 ArrayRef<int64_t> dstShape,
915 SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
916
917 // If the shape is dynamic, create a map for collapsing into one dimension.
918 if (isDynamic) {
919 SmallVector<AffineExpr, 2> exprs;
920 for (int i = 0, s = srcShape.size(); i < s; ++i)
921 exprs.push_back(rewriter.getAffineDimExpr(i));
922 reassociationMap = {exprs};
923 return true;
924 }
925
926 if (dstShape.empty()) {
927 reassociationMap = {};
928 return true;
929 }
930
931 reassociationMap.resize(dstShape.size());
932 unsigned currSrcDim = 0, currDstDim = 0;
933 while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
934 int64_t dstSize = dstShape[currDstDim];
935 int64_t srcSize = srcShape[currSrcDim];
936 while (srcSize < dstSize && currSrcDim < srcShape.size()) {
937 reassociationMap[currDstDim].push_back(
938 rewriter.getAffineDimExpr(currSrcDim++));
939 srcSize *= srcShape[currSrcDim];
940 }
941 if (srcSize == dstSize) {
942 reassociationMap[currDstDim].push_back(
943 rewriter.getAffineDimExpr(currSrcDim++));
944 // If the next dim in collapsedShape is not 1, treat subsequent dims in
945 // expandedShape which are 1 to be collapsed.
946 if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
947 while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
948 reassociationMap[currDstDim].push_back(
949 rewriter.getAffineDimExpr(currSrcDim++));
950 }
951 }
952 }
953 currDstDim++;
954 }
955
956 // If both iterators didn't reach the end, we have leftover dimentions which
957 // implies that we have a mismatch in shape.
958 return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
959 }
960
961 namespace {
962
963 template <typename SrcOp>
964 class PointwiseConverter : public OpRewritePattern<SrcOp> {
965 public:
966 using OpRewritePattern<SrcOp>::OpRewritePattern;
967
matchAndRewrite(SrcOp op,PatternRewriter & rewriter) const968 LogicalResult matchAndRewrite(SrcOp op,
969 PatternRewriter &rewriter) const final {
970 return elementwiseMatchAndRewriteHelper(op, rewriter);
971 }
972 };
973
974 class ReshapeConverterCollapse : public OpConversionPattern<tosa::ReshapeOp> {
975 public:
976 using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
977
978 LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const979 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
980 ConversionPatternRewriter &rewriter) const final {
981 ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
982 ShapedType resultTy = reshape.getType().template cast<ShapedType>();
983 bool isDynamic = !operandTy.hasStaticShape();
984
985 if (isDynamic && resultTy.getRank() != 1) {
986 return rewriter.notifyMatchFailure(
987 reshape, "Cannot collapse dynamic dims to more than one dimension");
988 }
989
990 if (operandTy == resultTy) {
991 rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
992 return success();
993 }
994
995 SmallVector<ReassociationExprs, 4> reassociationMap;
996 if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
997 resultTy.getShape(),
998 reassociationMap, isDynamic)) {
999 return rewriter.notifyMatchFailure(
1000 reshape,
1001 "tosa.reshape Attempting to collapse into an incompatible shape");
1002 }
1003
1004 SmallVector<int64_t> intermediateShape;
1005 if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
1006 intermediateShape, isDynamic)) {
1007 return rewriter.notifyMatchFailure(
1008 reshape, "tosa.reshape Cannot collapse into given shape");
1009 }
1010
1011 rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
1012 reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
1013 return success();
1014 }
1015 };
1016
1017 class ReshapeConverterExpand : public OpConversionPattern<tosa::ReshapeOp> {
1018 public:
1019 using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
1020
1021 LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1022 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
1023 ConversionPatternRewriter &rewriter) const final {
1024 ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
1025 ShapedType resultTy = reshape.getType().template cast<ShapedType>();
1026 bool isDynamic = !operandTy.hasStaticShape();
1027
1028 if (operandTy == resultTy) {
1029 rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
1030 return success();
1031 }
1032
1033 if (isDynamic && operandTy.getRank() != 1) {
1034 return rewriter.notifyMatchFailure(
1035 reshape, "Cannot expand dynamic dims from more than one dimension");
1036 }
1037
1038 SmallVector<ReassociationExprs, 4> reassociationMap;
1039 if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
1040 operandTy.getShape(),
1041 reassociationMap, isDynamic)) {
1042 return rewriter.notifyMatchFailure(
1043 reshape,
1044 "tosa.reshape Attempting to expand into an incompatible shape");
1045 }
1046
1047 SmallVector<int64_t> intermediateShape;
1048 if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
1049 intermediateShape, isDynamic) ||
1050 intermediateShape != operandTy.getShape()) {
1051 return rewriter.notifyMatchFailure(
1052 reshape, "tosa.reshape Cannot expand into given shape");
1053 }
1054 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1055 reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
1056 return success();
1057 }
1058 };
1059
1060 class ReshapeConverterCollapseExpand
1061 : public OpConversionPattern<tosa::ReshapeOp> {
1062 public:
1063 using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
1064
1065 LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const1066 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
1067 ConversionPatternRewriter &rewriter) const final {
1068 ShapedType operandTy = adaptor.getInput1().getType().cast<ShapedType>();
1069 ShapedType resultTy = reshape.getType().template cast<ShapedType>();
1070 bool isDynamic = !operandTy.hasStaticShape();
1071
1072 if (operandTy == resultTy) {
1073 rewriter.replaceOp(reshape, adaptor.getOperands()[0]);
1074 return success();
1075 }
1076
1077 SmallVector<int64_t> intermediateShape;
1078 if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
1079 intermediateShape, isDynamic)) {
1080 return rewriter.notifyMatchFailure(
1081 reshape, "tosa.reshape Cannot identify an intermediate shape between "
1082 "the given two shapes");
1083 }
1084
1085 Value collapse = rewriter.create<tosa::ReshapeOp>(
1086 reshape.getLoc(),
1087 RankedTensorType::get(intermediateShape,
1088 reshape.getType().getElementType()),
1089 adaptor.getInput1());
1090 Value expand =
1091 rewriter.create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
1092 rewriter.replaceOp(reshape, expand);
1093
1094 return success();
1095 }
1096 };
1097
1098 class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1099 public:
1100 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
1101
matchAndRewrite(tosa::TransposeOp op,PatternRewriter & rewriter) const1102 LogicalResult matchAndRewrite(tosa::TransposeOp op,
1103 PatternRewriter &rewriter) const final {
1104 DenseIntElementsAttr perms;
1105 if (!matchPattern(op.getPerms(), m_Constant(&perms))) {
1106 return failure();
1107 }
1108
1109 auto loc = op.getLoc();
1110 auto input = op->getOperand(0);
1111 auto resultTy = op.getType().cast<ShapedType>();
1112
1113 SmallVector<Value> dynDims;
1114 dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
1115
1116 SmallVector<AffineExpr, 2> inputExprs;
1117 inputExprs.resize(resultTy.getRank());
1118 auto operandTy = input.getType().cast<ShapedType>();
1119 for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
1120 auto index = permutation.index();
1121 auto value = permutation.value().getZExtValue();
1122 if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
1123 dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
1124 }
1125 inputExprs[value] = rewriter.getAffineDimExpr(index);
1126 }
1127
1128 SmallVector<Value> filteredDims = condenseValues(dynDims);
1129
1130 auto initTensor = rewriter.create<linalg::InitTensorOp>(
1131 loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
1132
1133 SmallVector<AffineMap, 2> affineMaps = {
1134 AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
1135 rewriter.getContext()),
1136 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1137
1138 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1139 op, resultTy, op.getInput1(), ValueRange{initTensor}, affineMaps,
1140 getNParallelLoopsAttrs(resultTy.getRank()),
1141 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
1142 nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
1143 });
1144 return success();
1145 }
1146 };
1147
// Lowers tosa.rescale to a linalg.generic that, per element: subtracts the
// input zero-point, applies the quantized multiplier/shift via
// tosa.apply_scale, adds the output zero-point, and saturates to the output
// integer type. Multiplier/shift are either per-tensor (scalar constants)
// or per-channel (constant tensors indexed on the innermost dimension).
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = op.getInput().getType().cast<ShapedType>();
    auto outputTy = op.getOutput().getType().cast<ShapedType>();
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    // Only the batch dimension may be dynamic; bail out otherwise.
    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues;
    getValuesFromIntArrayAttribute(op.getMultiplier(), multiplierValues);

    SmallVector<int8_t> shiftValues;
    getValuesFromIntArrayAttribute(op.getShift(), shiftValues);

    // If we shift by more than the bitwidth, the channel result is always 0,
    // so zero out both the shift and the multiplier.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only has an effect when the shift exceeds 31, so only
    // enable it if some channel actually uses such a shift.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer (indexed on the innermost dim); otherwise a single
    // scalar constant is captured for the generic body.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      // Block-argument index of the per-channel multiplier inside the body.
      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer, mirroring the multiplier handling above.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing maps for output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the output init tensor for the linalg.generic.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{initTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // Pick the internal math width: 48-bit when the input is wider
          // than 32 bits, otherwise 32-bit. This is not optimal but should
          // be correct for now; consider computing the exact required bit
          // depth later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          // Per-tensor scaling uses the captured scalars; per-channel reads
          // the extra block arguments added above.
          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Normalize the input to a signed 32-bit value. Unsigned inputs
          // are round-tripped through an unrealized cast so the zext is
          // legal on the builtin signless type.
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(
                              nestedLoc,
                              nestedBuilder.getIntegerType(
                                  valueTy.getIntOrFloatBitWidth()),
                              value)
                          .getResult(0);
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          // Remove the input zero-point before scaling.
          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              blockArgs.back().getType().cast<IntegerType>();
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampHelper<arith::CmpIOp>(
              nestedLoc, value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
              nestedBuilder);

          // Truncate back down to the declared output width, restoring the
          // unsigned type via an unrealized cast when needed.
          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);

            if (outIntType.isUnsignedInteger()) {
              value = nestedBuilder
                          .create<UnrealizedConversionCastOp>(nestedLoc,
                                                              outIntType, value)
                          .getResult(0);
            }
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
1341
// Lowers tosa.resize to a linalg.generic over the output tensor. For each
// output element the body reconstructs the source coordinate — in
// floating-point when shift == 0, otherwise in fixed-point with `shift`
// fractional bits — then reads the result with tensor.extract, either from
// the nearest pixel (NEAREST_NEIGHBOR) or as a weighted average of the four
// surrounding pixels (BILINEAR).
class ResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();
    auto resultElementTy = resultTy.getElementType();

    // NHWC layout: dims 1 and 2 are the spatial height/width of the image.
    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    // Only the batch dimension may be dynamic; bail out otherwise.
    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return failure();

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultElementTy);

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    // Create the generic with an empty region; the lengthy scalar body is
    // built directly below with the rewriter after moving into the block.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    OpBuilder::InsertionGuard regionGuard(rewriter);
    rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                         TypeRange({resultElementTy}), loc);
    // Output coordinates of the element being computed (N, H, W, C).
    Value batch = rewriter.create<linalg::IndexOp>(loc, 0);
    Value y = rewriter.create<linalg::IndexOp>(loc, 1);
    Value x = rewriter.create<linalg::IndexOp>(loc, 2);
    Value channel = rewriter.create<linalg::IndexOp>(loc, 3);

    // Bounds used to clamp computed source coordinates into the image.
    auto hwMin =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
    auto hMax = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(imageH - 1));
    auto wMax = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(imageW - 1));

    Value inY =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), y);
    Value inX =
        rewriter.create<arith::IndexCastOp>(loc, rewriter.getI32Type(), x);

    // shift == 0 selects the floating-point path; otherwise coordinates are
    // fixed-point with `shift` fractional bits.
    int32_t shift = op.getShift();
    bool floatingPointMode = shift == 0;

    Value yStride, xStride, yOffset, xOffset;
    if (floatingPointMode) {
      yStride = rewriter.create<arith::ConstantOp>(loc, op.getStrideFp()[0]);
      xStride = rewriter.create<arith::ConstantOp>(loc, op.getStrideFp()[1]);
      yOffset = rewriter.create<arith::ConstantOp>(loc, op.getOffsetFp()[0]);
      xOffset = rewriter.create<arith::ConstantOp>(loc, op.getOffsetFp()[1]);
    } else {
      SmallVector<int32_t> stride, offset;
      getValuesFromIntArrayAttribute(op.getStride(), stride);
      getValuesFromIntArrayAttribute(op.getOffset(), offset);

      yStride = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(stride[0]));
      xStride = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(stride[1]));
      yOffset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(offset[0]));
      xOffset = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(offset[1]));
    }

    // Compute the integer index and partial offset:
    // x = x * stride + offset;
    // ix = floor(x)
    // dx = x - ix
    Value ix, iy, dx, dy;
    if (floatingPointMode) {
      Value y =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inY);
      Value x =
          rewriter.create<arith::UIToFPOp>(loc, rewriter.getF32Type(), inX);

      y = rewriter.create<arith::MulFOp>(loc, y, yStride);
      x = rewriter.create<arith::MulFOp>(loc, x, xStride);

      y = rewriter.create<arith::AddFOp>(loc, y, yOffset);
      x = rewriter.create<arith::AddFOp>(loc, x, xOffset);

      iy = rewriter.create<math::FloorOp>(loc, y);
      ix = rewriter.create<math::FloorOp>(loc, x);

      dy = rewriter.create<arith::SubFOp>(loc, y, iy);
      dx = rewriter.create<arith::SubFOp>(loc, x, ix);

      iy = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), iy);
      ix = rewriter.create<arith::FPToSIOp>(loc, rewriter.getI32Type(), ix);
    } else {
      // Fixed-point: the integer part is an arithmetic right shift; the
      // fractional part is the remainder after re-shifting left.
      Value shiftVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(shift));

      Value y = rewriter.create<arith::MulIOp>(loc, inY, yStride);
      Value x = rewriter.create<arith::MulIOp>(loc, inX, xStride);

      y = rewriter.create<arith::AddIOp>(loc, y, yOffset);
      x = rewriter.create<arith::AddIOp>(loc, x, xOffset);

      iy = rewriter.create<arith::ShRSIOp>(loc, y, shiftVal);
      ix = rewriter.create<arith::ShRSIOp>(loc, x, shiftVal);

      Value yTrunc = rewriter.create<arith::ShLIOp>(loc, iy, shiftVal);
      Value xTrunc = rewriter.create<arith::ShLIOp>(loc, ix, shiftVal);

      dy = rewriter.create<arith::SubIOp>(loc, y, yTrunc);
      dx = rewriter.create<arith::SubIOp>(loc, x, xTrunc);
    }

    if (op.getMode() == "NEAREST_NEIGHBOR") {
      Value yPred, xPred;
      // Round the index position towards the closest pixel location.
      if (floatingPointMode) {
        auto halfVal = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getF32FloatAttr(0.5f));
        yPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                               dy, halfVal);
        xPred = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                               dx, halfVal);
      } else {
        // In fixed-point, "0.5" is 1 << (shift - 1).
        auto halfVal = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1 << (shift - 1)));
        yPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                               dy, halfVal);
        xPred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                               dx, halfVal);
      }

      auto zeroVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(0));
      auto oneVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(1));

      auto yOffset =
          rewriter.create<arith::SelectOp>(loc, yPred, oneVal, zeroVal);
      auto xOffset =
          rewriter.create<arith::SelectOp>(loc, xPred, oneVal, zeroVal);

      iy = rewriter.create<arith::AddIOp>(loc, iy, yOffset);
      ix = rewriter.create<arith::AddIOp>(loc, ix, xOffset);

      // Clamp to be within the bounds of the input image.

      iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
                                      arith::CmpIPredicate::slt, rewriter);
      ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
                                      arith::CmpIPredicate::slt, rewriter);

      // Read the value from the input array.
      iy =
          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), iy);
      ix =
          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), ix);

      Value result = rewriter.create<tensor::ExtractOp>(
          loc, input, ValueRange{batch, iy, ix, channel});

      rewriter.create<linalg::YieldOp>(loc, result);

      return success();
    }

    if (op.getMode() == "BILINEAR") {
      // The four neighboring pixels are (y0,y1) x (x0,x1), each coordinate
      // clamped to the image bounds independently.
      Value y0 = iy;
      Value x0 = ix;

      auto oneVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(1));
      Value y1 = rewriter.create<arith::AddIOp>(loc, y0, oneVal);
      Value x1 = rewriter.create<arith::AddIOp>(loc, x0, oneVal);

      y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
                                      arith::CmpIPredicate::slt, rewriter);
      y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
                                      arith::CmpIPredicate::slt, rewriter);

      x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
                                      arith::CmpIPredicate::slt, rewriter);
      x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
                                      arith::CmpIPredicate::slt, rewriter);

      y0 =
          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y0);
      y1 =
          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), y1);
      x0 =
          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x0);
      x1 =
          rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), x1);

      Value y0x0 = rewriter.create<tensor::ExtractOp>(
          loc, input, ValueRange{batch, y0, x0, channel});
      Value y0x1 = rewriter.create<tensor::ExtractOp>(
          loc, input, ValueRange{batch, y0, x1, channel});
      Value y1x0 = rewriter.create<tensor::ExtractOp>(
          loc, input, ValueRange{batch, y1, x0, channel});
      Value y1x1 = rewriter.create<tensor::ExtractOp>(
          loc, input, ValueRange{batch, y1, x1, channel});

      if (floatingPointMode) {
        auto oneVal = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getF32FloatAttr(1.f));
        // Interpolate horizontally on the top and bottom rows, then blend
        // the two rows vertically.
        Value rightPart = dx;
        Value leftPart = rewriter.create<arith::SubFOp>(loc, oneVal, dx);

        y0x0 = rewriter.create<arith::MulFOp>(loc, y0x0, leftPart);
        y0x1 = rewriter.create<arith::MulFOp>(loc, y0x1, rightPart);
        Value topAcc = rewriter.create<arith::AddFOp>(loc, y0x0, y0x1);

        y1x0 = rewriter.create<arith::MulFOp>(loc, y1x0, leftPart);
        y1x1 = rewriter.create<arith::MulFOp>(loc, y1x1, rightPart);
        Value bottomAcc = rewriter.create<arith::AddFOp>(loc, y1x0, y1x1);

        Value bottomPart = dy;
        Value topPart = rewriter.create<arith::SubFOp>(loc, oneVal, dy);
        topAcc = rewriter.create<arith::MulFOp>(loc, topAcc, topPart);
        bottomAcc = rewriter.create<arith::MulFOp>(loc, bottomAcc, bottomPart);
        Value result = rewriter.create<arith::AddFOp>(loc, topAcc, bottomAcc);

        rewriter.create<linalg::YieldOp>(loc, result);
        return success();
      }
      // Fixed-point bilinear: widen the pixels to the (wider) result element
      // type first so the weight multiplies do not overflow.
      y0x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
      y0x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
      y1x0 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
      y1x1 = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, y1x1);

      if (resultElementTy.getIntOrFloatBitWidth() > 32) {
        dx = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dx);
        dy = rewriter.create<arith::ExtSIOp>(loc, resultElementTy, dy);
      }

      // In fixed-point, "1.0" is 1 << shift.
      auto unitVal = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(resultElementTy, 1LL << shift));
      Value rightPart = dx;
      Value leftPart = rewriter.create<arith::SubIOp>(loc, unitVal, dx);

      y0x0 = rewriter.create<arith::MulIOp>(loc, y0x0, leftPart);
      y0x1 = rewriter.create<arith::MulIOp>(loc, y0x1, rightPart);
      Value topAcc = rewriter.create<arith::AddIOp>(loc, y0x0, y0x1);

      y1x0 = rewriter.create<arith::MulIOp>(loc, y1x0, leftPart);
      y1x1 = rewriter.create<arith::MulIOp>(loc, y1x1, rightPart);
      Value bottomAcc = rewriter.create<arith::AddIOp>(loc, y1x0, y1x1);

      Value bottomPart = dy;
      Value topPart = rewriter.create<arith::SubIOp>(loc, unitVal, dy);
      topAcc = rewriter.create<arith::MulIOp>(loc, topAcc, topPart);
      bottomAcc = rewriter.create<arith::MulIOp>(loc, bottomAcc, bottomPart);
      Value result = rewriter.create<arith::AddIOp>(loc, topAcc, bottomAcc);

      rewriter.create<linalg::YieldOp>(loc, result);
      return success();
    }
    return failure();
  }
};
1614
1615 // At the codegen level any identity operations should be removed. Any cases
1616 // where identity is load-bearing (e.g. cross device computation) should be
1617 // handled before lowering to codegen.
1618 template <typename SrcOp>
1619 class IdentityNConverter : public OpRewritePattern<SrcOp> {
1620 public:
1621 using OpRewritePattern<SrcOp>::OpRewritePattern;
1622
matchAndRewrite(SrcOp op,PatternRewriter & rewriter) const1623 LogicalResult matchAndRewrite(SrcOp op,
1624 PatternRewriter &rewriter) const final {
1625 rewriter.replaceOp(op, op.getOperation()->getOperands());
1626 return success();
1627 }
1628 };
1629
1630 template <typename SrcOp>
1631 class ReduceConverter : public OpRewritePattern<SrcOp> {
1632 public:
1633 using OpRewritePattern<SrcOp>::OpRewritePattern;
1634
matchAndRewrite(SrcOp reduceOp,PatternRewriter & rewriter) const1635 LogicalResult matchAndRewrite(SrcOp reduceOp,
1636 PatternRewriter &rewriter) const final {
1637 return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
1638 }
1639 };
1640
// Lowers tosa.concat to an init_tensor + fill, followed by one
// tensor.insert_slice per operand. The insertion offset along the
// concatenation axis is advanced by each operand's extent.
struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
    auto resultType = op.getType().dyn_cast<RankedTensorType>();

    Location loc = op.getLoc();
    int axis = op.getAxis();
    // Index constant for the concat axis, used to query dynamic extents.
    Value axisValue = rewriter.createOrFold<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(axis));
    int rank = resultType.getRank();
    // Slice parameters for the insert_slice ops: unit strides, zero offsets;
    // sizes are filled per-dimension below.
    SmallVector<Value, 3> offsets, sizes, strides;
    sizes.reserve(rank);
    strides.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 1));
    offsets.resize(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));

    // Start with the first operand's sizes and collect runtime extents for
    // any dynamic result dimensions (needed by init_tensor below).
    SmallVector<Value> dynDims;
    for (int i = 0; i < rank; ++i) {
      sizes.push_back(rewriter.createOrFold<tensor::DimOp>(
          loc, adaptor.getOperands()[0], i));
      if (inputType.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    // The result extent along the concat axis is the sum of all operands'
    // extents along that axis.
    Value resultDimSize = sizes[axis];
    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      resultDimSize =
          rewriter.createOrFold<arith::AddIOp>(loc, resultDimSize, size);
    }
    sizes[axis] = resultDimSize;

    Value init = rewriter.create<linalg::InitTensorOp>(
        loc, dynDims, resultType.getShape(), resultType.getElementType());

    // Zero-fill the destination before inserting the slices.
    Value zeroVal = rewriter.createOrFold<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(resultType.getElementType()));
    Value result =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{zeroVal}, ValueRange{init})
            .result();

    // Fold constant index values into attributes so statically-known slice
    // parameters stay static in the generated insert_slice ops.
    auto toOpFoldResult = [](Value v) -> OpFoldResult {
      auto op = v.getDefiningOp<arith::ConstantIndexOp>();
      if (!op)
        return v;
      return op.getValue();
    };
    // Insert each operand at the running offset along the concat axis, then
    // bump the offset by that operand's extent.
    for (auto arg : adaptor.getOperands()) {
      sizes[axis] = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result,
          llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
          llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
          llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
      offsets[axis] =
          rewriter.createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};
1708
// Lowers tosa.reverse to a linalg.generic whose body gathers from the input
// with the coordinate along the reversed axis mirrored:
// index -> (axisDimSize - 1) - index.
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput();
    auto inputTy = input.getType().template cast<ShapedType>();
    auto resultTy = op.getType().template cast<ShapedType>();
    auto axis = op.getAxis();

    // Runtime extents for dynamic input dims, consumed by init_tensor below.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // Extent of the reversed axis, used to mirror indices in the body.
    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // Destination tensor for the generic op; every element is overwritten by
    // the body, so no fill is required.
    auto initTensor = rewriter
                          .create<linalg::InitTensorOp>(
                              loc, ArrayRef<Value>({dynDims}),
                              inputTy.getShape(), inputTy.getElementType())
                          .result();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          // Build the gather coordinates: identity on every dim except the
          // reversed axis, which is mirrored.
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            auto index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              // index = (axisDimSize - 1) - index
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          // Read the mirrored element and yield it as this output element.
          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
1766
// This converter translates a tile operation to a reshape, broadcast, reshape.
// The first reshape minimally expands each tiled dimension to include a
// preceding size-1 dim. This dim is then broadcast to the appropriate
// multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  // Lowers tosa.tile to a linalg.generic that broadcasts the input into an
  // interleaved [multiple_0, size_0, multiple_1, size_1, ...] shape, followed
  // by a tosa.reshape collapsing each (multiple, size) pair into one result
  // dimension.
  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = input.getType().cast<ShapedType>();
    auto inputShape = inputTy.getShape();
    auto resultTy = op.getType().cast<ShapedType>();
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    getValuesFromIntArrayAttribute(op.getMultiples(), multiples);

    // Broadcast the newly added dimensions to their appropriate multiple:
    // dim 2*i holds the tile count, dim 2*i+1 the original extent.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      genericShape.push_back(multiples[i]);
      genericShape.push_back(inputShape[i]);
    }

    // Runtime extents for dims whose size (or multiple) is not static.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        op.getLoc(), dynDims, genericShape, elementTy);

    // We need to map the input shape to the non-broadcasted dimensions, i.e.
    // the odd dims of genericShape; the even (multiple) dims are dropped from
    // the read map so the input is re-read for every tile.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    // The body simply forwards the read input element; the broadcast is fully
    // expressed by the indexing maps.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    // Collapse the interleaved shape back to the tiled result shape.
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};
1832
// Lowers tosa.pad to a scalar-padded tensor.pad. The per-dimension pad
// amounts are read from the `padding` tensor (element (i, 0) = low pad,
// (i, 1) = high pad for dimension i, as extracted below); the fill value
// comes from pad_const, the quantization input zero point, or a zero of the
// element type, in that order of preference.
class PadConverter : public OpRewritePattern<tosa::PadOp> {
public:
  using OpRewritePattern<tosa::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::PadOp padOp,
                                PatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();
    auto padding = padOp.getPadding();

    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    // Setup the default constantAttr.

    Value padConstant;

    if (padOp.getPadConst()) {
      // Explicit pad constant: a rank-0 tensor whose single element is the
      // fill value.
      padConstant = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padOp.getPadConst(), ValueRange({}));
    } else {
      // Otherwise synthesize the fill value: 0.0 for floats, the quantization
      // zero point for quantized integers, plain 0 for other integers.
      Attribute constantAttr;
      if (elementTy.isa<FloatType>()) {
        constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
      } else if (elementTy.isa<IntegerType>() && !padOp.getQuantizationInfo()) {
        constantAttr = rewriter.getIntegerAttr(elementTy, 0);
      } else if (elementTy.isa<IntegerType>() && padOp.getQuantizationInfo()) {
        int64_t value = padOp.getQuantizationInfo()->getInputZp();
        constantAttr = rewriter.getIntegerAttr(elementTy, value);
      }
      if (constantAttr)
        padConstant = rewriter.create<arith::ConstantOp>(loc, constantAttr);
    }

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    // Column indices into the padding tensor: 0 = low pad, 1 = high pad.
    Value lowIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

    SmallVector<OpFoldResult, 3> lowValues;
    SmallVector<OpFoldResult, 3> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    // Extract the (low, high) pad amounts for each dimension and cast them to
    // index type for tensor.pad.
    for (int i = 0; i < rank; i++) {
      Value inputIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, i);
      Value lowVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, lowIndex}));
      Value highVal = rewriter.createOrFold<tensor::ExtractOp>(
          loc, padding, ValueRange({inputIndex, highIndex}));

      lowVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), lowVal);
      highVal = rewriter.createOrFold<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), highVal);

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = tensor::createPadScalarOp(
        padOp.getType(), input, padConstant, lowValues, highValues,
        /*nofold=*/false, loc, rewriter);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
1908
// Tosa argmax lowering represents the ArgMax op as a linalg.indexed_generic
// op, producing two output buffers.
//
// The first output buffer contains the index of the found maximum value. It is
// initialized to 0 and is of the resulting integer type.
//
// The second output buffer contains the maximum value found. It is initialized
// to the minimum representable value of the input element type. After being
// populated by indexed_generic, this buffer is discarded as only the index is
// requested.
//
// The indexed_generic op updates both the maximum value and index if the
// current value exceeds the running max.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = input.getType().cast<ShapedType>();
    auto resultTy = argmaxOp.getOutput().getType().cast<ShapedType>();
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    // Type of the auxiliary running-max result: same shape as the index
    // result but carrying the input element type.
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!outElementTy.isa<IntegerType>())
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    // Runtime extents for dynamic result dims (all input dims except the
    // reduced axis), consumed by the init_tensor ops below.
    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index, starting at index 0.
    auto initTensorIdx =
        rewriter
            .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
                                          outElementTy)
            .result();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{initTensorIdx})
            .result();

    // Second fill the output buffer for the running max with the identity of
    // the max reduction (minimum representable value of the element type).
    auto initTensorMax = rewriter
                             .create<linalg::InitTensorOp>(
                                 loc, dynDims, resultTy.getShape(), inElementTy)
                             .result();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{initTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations along
    // the rest.
    SmallVector<StringRef, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), getParallelIteratorTypeName());
    iteratorTypes[axis] = getReductionIteratorTypeName();

    // Indexing maps: identity for the input; the outputs drop the reduced
    // axis.
    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    // The body builder cannot return failure directly, so record unsupported
    // element types in this flag and check it after the op is built.
    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          // blockArgs: current input element, running index, running max.
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          // Current position along the reduced axis, cast to the index
          // element type of the result.
          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (inElementTy.isa<FloatType>()) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (inElementTy.isa<IntegerType>()) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          // Keep the larger value and its index.
          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    // Only the index result is requested; the max result is dropped.
    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
2040
// Lowers tosa.gather to a linalg.generic over the result shape whose body
// extracts input[batch, indices[batch, n], channel] for each output element.
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto resultTy = op.getType().cast<ShapedType>();

    // Only the batch dimension may be dynamic; bail out otherwise.
    auto dynamicDimsOr = checkHasDynamicBatchDims(
        rewriter, op, {input, indices, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, dynamicDims, resultTy.getShape(),
                                          resultElementTy)
            .result();

    // Indexing maps: the indices operand is read at (batch, n), i.e. dims
    // (d0, d1) of the (d0, d1, d2) result space; the output uses identity.
    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          // args[0] is the gathered index; use it as the middle coordinate
          // between the batch (dim 0) and channel (dim 2) positions.
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
2093
// Lowers the TableOp to a series of gathers and numeric operations. This
// includes interpolation between the high/low values. For the I8 variant, this
// simplifies to a single gather operation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput();
    Value table = op.getTable();
    auto inputTy = input.getType().cast<ShapedType>();
    auto tableTy = table.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    // Runtime extents for dynamic dims, consumed by init_tensor below.
    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
                                          resultElementTy)
            .result();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    // Create the generic op with an empty region; the body is built below so
    // each supported type combination can emit its own lookup code.
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.region(), genericOp.region().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      // i8 -> i8 table: a single lookup, with the index biased by 128 so the
      // signed input range maps onto non-negative table positions.
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      // i16 -> i32 table: two lookups plus linear interpolation on the low
      // 7 bits of the biased input.
      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        // value = value + 32768
        // index = value >> 7;
        // fraction = 0x7F & value
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        // base = (int32_t) table[index];
        // next = (int32_t) table[index + 1];
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the input values:
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
2219
2220 } // namespace
2221
// Registers all Tosa-to-Linalg lowering patterns defined in this file on the
// given pattern set. Elementwise ops share the PointwiseConverter template,
// reductions share ReduceConverter, and the remaining ops each have a
// dedicated pattern class.
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::DivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::ReluNOp>,
      PointwiseConverter<tosa::SigmoidOp>,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      ConcatConverter,
      GatherConverter,
      PadConverter,
      ReshapeConverterCollapse,
      ReshapeConverterExpand,
      ReshapeConverterCollapseExpand,
      RescaleConverter,
      ResizeConverter,
      ReverseConverter,
      TableConverter,
      TileConverter,
      TransposeConverter>(patterns->getContext());
  // clang-format on
}
2284