//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa dialect to the Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Input should be padded if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return tensor::createPadScalarOp(RankedTensorType::get(paddedShape, inputETy),
                                   input, padValue, lowIndices, highIndices,
                                   /*nofold=*/false, loc, rewriter)
      .getResult();
}

static mlir::Value reifyConstantDim(Attribute attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.createOrFold<arith::IndexCastOp>(
      builder.getIndexType(), builder.create<arith::ConstantOp>(attr));
}

// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
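// For example (values chosen purely for illustration): IH=16,
// pad_top=pad_bottom=1, KH=3, dilation_y=1, and stride_y=2 give
// H = ((16+1+1-(1*(3-1)+1))/2)+1 = (15/2)+1 = 8, where the division is
// the unsigned integer division materialized below.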
static mlir::Value
getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr,
                 Attribute padAfterAttr, Value kernelDim, Attribute strideAttr,
                 Attribute dilationAttr, Type inputETy, OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(initDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(initDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);

  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);

  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  // Final +1 from the output-size formula above.
  return builder.create<arith::AddIOp>(divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayAttr padAttr, ArrayAttr strideAttr, ArrayAttr dilationAttr,
    int64_t weightHDim, int64_t weightWDim, OpBuilder &rewriter) {
  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  int64_t inputRank = inputTy.getRank();
  int64_t heightDim = 1;
  int64_t widthDim = 2;

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());
  for (int i = 0; i < inputRank; i++) {
    if (inputTy.isDynamicDim(i) && i != heightDim && i != widthDim)
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  // Dynamic input height
  if (inputTy.isDynamicDim(heightDim)) {
    Value initHDim =
        rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
    Value kernelHDim =
        rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
    // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
    dynDims[heightDim] = getConvOutputDim(
        loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1], kernelHDim,
        strideAttr.getValue()[0], dilationAttr.getValue()[0], inputETy,
        rewriter);
  }

  // Dynamic input width
  if (inputTy.isDynamicDim(widthDim)) {
    Value initWDim =
        rewriter.create<tensor::DimOp>(loc, input, widthDim).getResult();
    Value kernelWDim =
        rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
    // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
    dynDims[widthDim] = getConvOutputDim(
        loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3], kernelWDim,
        strideAttr.getValue()[1], dilationAttr.getValue()[1], inputETy,
        rewriter);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a map to collapse the last dimension of the Depthwise convolution op
// due to a shape mismatch
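// For example, with outputRank == 4 the reassociation is
// [[d0], [d1], [d2], [d3, d4]]: the 5-D (N, H, W, C, M) result of the
// linalg depthwise convolution collapses into the 4-D (N, H, W, C*M)
// TOSA result.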
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}

namespace {

class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
    bool isQuantized = op->hasAttr("quantization_info");

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
        /*weightHDim=*/1, /*weightWDim=*/2, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

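    // The pad vector built below follows applyPad's {low, high} interleaving
    // for the NHWC input: {0, 0, pad_top, pad_bottom, pad_left, pad_right,
    // 0, 0}, i.e. no padding on the batch and channel dimensions.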
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Transpose the kernel to match dimension ordering of the linalg
    // convolution operation.
    // TODO(suderman): See if this can be efficiently folded - check whether
    // the input is used anywhere else, if not fold the constant.
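    // tosa.conv2d weights are laid out as (OC, KH, KW, IC), while
    // linalg.conv_2d_nhwc_hwcf expects (KH, KW, IC, OC); hence the
    // {1, 2, 3, 0} permutation.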
    SmallVector<int64_t> weightPerm{1, 2, 3, 0};
    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    auto weightPermAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
    Value weightPermValue =
        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());
    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                weightPermValue);

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    // Create maps for the bias broadcasting
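    // The bias is rank-1, holding one value per output channel, so its map
    // (d0, d1, d2, d3) -> (d3) broadcasts it across the batch and spatial
    // dimensions of the NHWC result.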
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);

    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::Conv2DNhwcHwcfQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddIOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    Value conv = rewriter
                     .create<linalg::Conv2DNhwcHwcfOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
                     ->getResult(0);

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added = nestedBuilder.create<arith::AddFOp>(
                      loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);

    rewriter.replaceOp(op, result);
    return success();
  }
};

class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute the dynamic output dims. The depthwise weight layout is
    // (KH, KW, C, M), so the kernel height/width live in dims 0 and 1.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
        /*weightHDim=*/0, /*weightWDim=*/1, rewriter);

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
    }

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, linalgConvTy.getShape(), resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);
    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddIOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();

    auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
    auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();

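    // tosa.matmul is a batched matmul: A is (N, H, C), B is (N, C, W), and
    // the result is (N, H, W). The batch and row dims therefore come from
    // operand A and the column dim from operand B.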
    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();
    if (!op.getQuantizationInfo()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.getInput();
    auto inputTy = input.getType().cast<ShapedType>();

    auto bias = op.getBias();

    auto weight = op.getWeight();
    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    // Creating maps for the output of MatMul and the bias
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());

    // When quantized, the input element type is not the same as the output
    // type.
    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

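    // tosa.fully_connected weights are (OC, IC), whereas linalg.matmul
    // expects its RHS as (IC, OC); transpose with the {1, 0} permutation.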
    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    auto biasInitTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, filteredDims,
                                          outputTy.getShape(), outputETy)
            ->getResults();

    if (!op.getQuantizationInfo()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, zeroTensor)
                         ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto inputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
    auto weightZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, weightZp},
                zeroTensor)
            ->getResult(0);
    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added = nestedBuilder.create<arith::AddIOp>(
                      loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};

class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = input.getType().cast<ShapedType>();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = inputTy.getElementType();

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    // Determine what the initial value needs to be for the max pool op.
    Attribute initialAttr;
    if (resultETy.isF32())
      initialAttr = rewriter.getFloatAttr(
          resultETy,
          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
                              true));

    if (resultETy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.getPad(), pad);
    pad.resize(pad.size() + 2, 0);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.getKernel(), kernel);
    getValuesFromIntArrayAttribute(op.getStride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{initTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledInitTensor, strideAttr, dilationAttr);
    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = op.getType().cast<ShapedType>().getElementType();

    Type accETy =
        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.getPad(), pad);
    pad.resize(pad.size() + 2, 0);
    Attribute padAttr = rewriter.getZeroAttr(inElementTy);
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    Attribute initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.getKernel(), kernel);
    getValuesFromIntArrayAttribute(op.getStride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, accTy.getShape(), accETy);

    Value filledInitTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolInitTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledInitTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());

    Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultETy);

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericInitTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
          auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
          auto iH = rewriter.create<arith::ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(1) - 1);
          auto iW = rewriter.create<arith::ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(2) - 1);

          // Compute the indices from either end.
          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
          auto y1 = rewriter.create<arith::SubIOp>(loc, iH, y0);
          auto x1 = rewriter.create<arith::SubIOp>(loc, iW, x0);

          // Determine what portion of the valid input the kernel covers.
          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
            if (pad == 0)
              return v;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dx = rewriter.create<arith::SubIOp>(loc, x, padVal);

            Value cmp = rewriter.create<arith::CmpIOp>(
                loc, arith::CmpIPredicate::slt, dx, zero);
            Value offset = rewriter.create<arith::SelectOp>(loc, cmp, dx, zero);
            return rewriter.create<arith::AddIOp>(loc, v, offset)->getResult(0);
          };
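          // For example, with kernel height 3 and pad_top 1, the first
          // output row (y0 == 0) covers only two valid input rows:
          // dx = 0 - 1 = -1, so the kernel extent shrinks from 3 to 2.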

          // Compute the vertical component of coverage.
          auto kH0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[0]);
          auto kH1 = padFn(kH0, y0, pad[2]);
          auto kH2 = padFn(kH1, y1, pad[3]);
          auto kHCmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::slt, kH2, one);
          auto kH3 = rewriter.create<arith::SelectOp>(loc, kHCmp, one, kH2);

          // Compute the horizontal component of coverage.
          auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
          auto kW1 = padFn(kW0, x0, pad[4]);
          auto kW2 = padFn(kW1, x1, pad[5]);
          auto kWCmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::slt, kW2, one);
          auto kW3 = rewriter.create<arith::SelectOp>(loc, kWCmp, one, kW2);

          // Compute the total number of elements and normalize.
          Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);
          auto countI = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(), count);

          // Divide by the number of summed values. For floats this is just
          // a division; for quantized values, input normalization must be
          // applied first.
          Value poolVal = args[0];
          if (accETy.isa<FloatType>()) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, countI);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute the multiplier and shift values for the quantization
            // normalization. More bits would be preferable, but 32 bits
            // should be enough for this computation; a straight division
            // would also suffice.
            int64_t numerator = ((1 << 30) + 1);
            int64_t shift = 30;
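            // Sketch of the arithmetic: tosa.apply_scale computes roughly
            // (poolVal * multiplier) >> shift, so with multiplier =
            // ((1 << 30) + 1) / count and shift = 30 the result is
            // approximately poolVal / count. For a fully covered 3x3
            // kernel, count = 9 and multiplier is about 2^30 / 9.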

            Value numeratorVal = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(numerator));
            Value multiplierVal =
                rewriter
                    .create<arith::DivUIOp>(loc, rewriter.getI32Type(),
                                            numeratorVal, countI)
                    .getResult();
            Value shiftVal = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(shift));

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
                        shiftVal, rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information we need to apply output
            // zeropoint.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(),
                                        quantizationInfo.getOutputZp()));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampHelper<arith::CmpIOp>(
                loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

} // namespace

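// A minimal usage sketch (assumed, not part of this file): a conversion
// pass would typically do
//   RewritePatternSet patterns(&getContext());
//   mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
//   (void)applyFullConversion(getOperation(), target, std::move(patterns));
// with a ConversionTarget that marks the lowered tosa ops illegal.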
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      ConvConverter,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}