//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa dialect to the Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

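// Pads the input tensor with `padAttr` using the interleaved (low, high)
// padding amounts in `pad`, returning the input unchanged when every pad is
// zero.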
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            Attribute padAttr, OpBuilder &rewriter) {
  // Input should be padded if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (int i = 0, s = inputShape.size(); i < s; i++) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return tensor::createPadScalarOp(RankedTensorType::get(paddedShape, inputETy),
                                   input, padValue, lowIndices, highIndices,
                                   /*nofold=*/false, loc, rewriter)
      .getResult();
}

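// Materializes an integer attribute as an index-typed SSA value.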
static mlir::Value reifyConstantDim(Attribute attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.createOrFold<arith::IndexCastOp>(
      builder.getIndexType(), builder.create<arith::ConstantOp>(attr));
}

// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
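// For example, IH=14, pad_top=pad_bottom=1, KH=3, dilation_y=1 and stride_y=2
// give H = ((14+1+1-(1*(3-1)+1))/2)+1 = (13/2)+1 = 7 (using integer division).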
static mlir::Value
getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr,
                 Attribute padAfterAttr, Value kernelDim, Attribute strideAttr,
                 Attribute dilationAttr, Type inputETy, OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(initDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(initDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);

  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);

  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  return builder.create<arith::AddIOp>(divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D.
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayAttr padAttr, ArrayAttr strideAttr, ArrayAttr dilationAttr,
    int64_t weightHDim, int64_t weightWDim, OpBuilder &rewriter) {
  ShapedType inputTy = input.getType().cast<ShapedType>();
  Type inputETy = inputTy.getElementType();
  int64_t inputRank = inputTy.getRank();
  int64_t heightDim = 1;
  int64_t widthDim = 2;

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());
  for (int i = 0; i < inputRank; i++) {
    if (inputTy.isDynamicDim(i) && i != heightDim && i != widthDim)
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  // Dynamic input height.
  if (inputTy.isDynamicDim(heightDim)) {
    Value initHDim =
        rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
    Value kernelHDim =
        rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
    // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
    dynDims[heightDim] = getConvOutputDim(
        loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1], kernelHDim,
        strideAttr.getValue()[0], dilationAttr.getValue()[0], inputETy,
        rewriter);
  }

  // Dynamic input width.
  if (inputTy.isDynamicDim(widthDim)) {
    Value initWDim =
        rewriter.create<tensor::DimOp>(loc, input, widthDim).getResult();
    Value kernelWDim =
        rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
    // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
    dynDims[widthDim] = getConvOutputDim(
        loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3], kernelWDim,
        strideAttr.getValue()[1], dilationAttr.getValue()[1], inputETy,
        rewriter);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a map to collapse the last dimension of the Depthwise convolution op
// due to a shape mismatch.
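// For example, with outputRank == 4 this produces the reassociation
// [[0], [1], [2], [3, 4]], folding the channel and multiplier dimensions of
// the rank-5 depthwise result into the final output dimension.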
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}

namespace {

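// Converts tosa.conv2d into linalg.conv_2d_nhwc_hwcf (or
// linalg.conv_2d_nhwc_hwcf_q when quantized), transposing the weights from
// OHWI to HWIO and adding the bias with a broadcasting linalg.generic.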
class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
public:
  using OpConversionPattern<tosa::Conv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
    bool isQuantized = op->hasAttr("quantization_info");

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
        /*weightHDim=*/1, /*weightWDim=*/2, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Transpose the kernel to match dimension ordering of the linalg
    // convolution operation.
    // TODO(suderman): See if this can be efficiently folded - check whether
    // the input is used anywhere else, if not fold the constant.
    SmallVector<int64_t> weightPerm{1, 2, 3, 0};
    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[2],
                                        weightShape[3], weightShape[0]};
    auto weightPermAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4}, rewriter.getI64Type()), weightPerm);
    Value weightPermValue =
        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());
    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                weightPermValue);

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);

    // Create maps for the bias broadcasting.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);

    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::Conv2DNhwcHwcfQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddIOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    Value conv = rewriter
                     .create<linalg::Conv2DNhwcHwcfOp>(
                         loc, resultTy, ValueRange{input, weight},
                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
                     ->getResult(0);

    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, resultTy, ValueRange({bias, conv}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added = nestedBuilder.create<arith::AddFOp>(
                      loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);

    rewriter.replaceOp(op, result);
    return success();
  }
};

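// Converts tosa.depthwise_conv2d into linalg.depthwise_conv_2d_nhwc_hwcm (or
// its quantized variant), collapses the trailing channel/multiplier pair of
// the rank-5 result, and adds the bias with a broadcasting linalg.generic.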
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = op->getAttr("pad").cast<ArrayAttr>();
    auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
    auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
        /*weightHDim=*/0, /*weightWDim=*/1, rewriter);

    bool isQuantized = op->hasAttr("quantization_info");
    IntegerAttr iZp;
    IntegerAttr kZp;
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
    }

    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    Attribute zeroAttr = rewriter.getZeroAttr(inputETy);
    if (isQuantized) {
      auto quantizationInfo =
          op->getAttr("quantization_info").cast<tosa::ConvOpQuantizationAttr>();
      int64_t iZp = quantizationInfo.getInputZp();

      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (iZp < intMin || iZp > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(padAttr, pad);
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    llvm::SmallVector<int64_t> stride, dilation;
    getValuesFromIntArrayAttribute(strideTosaAttr, stride);
    getValuesFromIntArrayAttribute(dilationTosaAttr, dilation);

    // Create the convolution op.
    auto strideAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), stride);
    auto dilationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              resultETy);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(AffineMap::get(
        /*dimCount=*/resultRank, /*symbolCount=*/0,
        {rewriter.getAffineDimExpr(3)}, rewriter.getContext()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, linalgConvTy.getShape(), resultETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

    Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, resultTy.getShape(), resultETy);
    if (!isQuantized) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasInitTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddIOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

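// Converts tosa.matmul into linalg.batch_matmul (or
// linalg.quantized_batch_matmul when quantization info is present).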
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = op.getType().cast<ShapedType>();
    auto outputElementTy = outputTy.getElementType();

    auto firstOperandTy = op->getOperand(0).getType().cast<ShapedType>();
    auto secondOperandTy = op->getOperand(1).getType().cast<ShapedType>();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();
    if (!op.getQuantizationInfo()) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};

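// Converts tosa.fully_connected into linalg.matmul (or
// linalg.quantized_matmul), transposing the weights and adding the bias with
// a broadcasting linalg.generic.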
class FullyConnectedConverter
    : public OpConversionPattern<tosa::FullyConnectedOp> {
public:
  using OpConversionPattern<tosa::FullyConnectedOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    auto outputTy = op.getType().cast<ShapedType>();
    auto input = op.getInput();
    auto inputTy = input.getType().cast<ShapedType>();

    auto bias = op.getBias();

    auto weight = op.getWeight();
    auto weightTy = weight.getType().cast<ShapedType>();
    auto weightShape = weightTy.getShape();

    auto outputETy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());

    if (!inputTy.hasRank() || inputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, input, 0);
    }

    if (!weightTy.hasRank() || weightTy.isDynamicDim(0)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, weight, 0);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    // Creating maps for the output of MatMul and the bias.
    SmallVector<AffineMap, 4> indexingMaps;

    // Broadcast the bias.
    indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                          {rewriter.getAffineDimExpr(1)},
                                          rewriter.getContext()));

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

    auto initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, filteredDims, outputTy.getShape(), outputTy.getElementType());

    // When quantized, the input element type is not the same as the output.
    Attribute resultZeroAttr = rewriter.getZeroAttr(outputETy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{initTensor})
                           .result();

    SmallVector<int64_t> permutation{1, 0};
    auto permutationAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({2}, rewriter.getI64Type()), permutation);
    Value permutationValue =
        rewriter.create<arith::ConstantOp>(loc, permutationAttr);

    SmallVector<int64_t> newWeightShape{weightShape[1], weightShape[0]};
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());

    Value transposedWeight = rewriter.create<tosa::TransposeOp>(
        loc, newWeightTy, weight, permutationValue);

    auto biasInitTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, filteredDims,
                                          outputTy.getShape(), outputETy)
            ->getResults();

    if (!op.getQuantizationInfo()) {
      Value matmul = rewriter
                         .create<linalg::MatmulOp>(
                             loc, TypeRange{op.getType()},
                             ValueRange{input, transposedWeight}, zeroTensor)
                         ->getResult(0);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added = nestedBuilder.create<arith::AddFOp>(
                        loc, args[0], args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    auto quantizationInfo = *op.getQuantizationInfo();
    auto inputZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
    auto weightZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
    Value matmul =
        rewriter
            .create<linalg::QuantizedMatmulOp>(
                loc, TypeRange{op.getType()},
                ValueRange{input, transposedWeight, inputZp, weightZp},
                zeroTensor)
            ->getResult(0);
    Value result =
        rewriter
            .create<linalg::GenericOp>(
                loc, outputTy, ValueRange({bias, matmul}), biasInitTensor,
                indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
                [&](OpBuilder &nestedBuilder, Location nestedLoc,
                    ValueRange args) {
                  Value added = nestedBuilder.create<arith::AddIOp>(
                      loc, args[0], args[1]);
                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                })
            .getResult(0);
    rewriter.replaceOp(op, result);
    return success();
  }
};

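// Converts tosa.max_pool2d into linalg.pooling_nhwc_max, padding the input
// with the minimum representable value so padded elements never affect the
// maximum.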
class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = input.getType().cast<ShapedType>();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = inputTy.getElementType();

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    // Determine what the initial value needs to be for the max pool op.
    Attribute initialAttr;
    if (resultETy.isF32())
      initialAttr = rewriter.getFloatAttr(
          resultETy,
          APFloat::getLargest(resultETy.cast<FloatType>().getFloatSemantics(),
                              true));

    if (resultETy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.getPad(), pad);
    pad.resize(pad.size() + 2, 0);
    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.getKernel(), kernel);
    getValuesFromIntArrayAttribute(op.getStride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{initTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, resultETy);

    rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(
        op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
        filledInitTensor, strideAttr, dilationAttr);
    return success();
  }
};

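// Converts tosa.avg_pool2d into a linalg.pooling_nhwc_sum followed by a
// linalg.generic that divides each output element by the number of valid
// (unpadded) input elements its kernel window covered.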
class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().template cast<ShapedType>();
    Type resultETy = op.getType().cast<ShapedType>().getElementType();

    Type accETy =
        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = dynamicDimsOr.value();

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    getValuesFromIntArrayAttribute(op.getPad(), pad);
    pad.resize(pad.size() + 2, 0);
    Attribute padAttr = rewriter.getZeroAttr(inElementTy);
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    Attribute initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride;
    getValuesFromIntArrayAttribute(op.getKernel(), kernel);
    getValuesFromIntArrayAttribute(op.getStride(), stride);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, accTy.getShape(), accETy);

    Value filledInitTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolInitTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledInitTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    auto poolingOpTy = poolingOp.getType().cast<ShapedType>();
    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());

    Value genericInitTensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamicDims, resultTy.getShape(), resultETy);

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericInitTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
          auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
          auto iH = rewriter.create<arith::ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(1) - 1);
          auto iW = rewriter.create<arith::ConstantIndexOp>(
              loc, poolingOpTy.getDimSize(2) - 1);

          // Compute the indices from either end.
          auto y0 = rewriter.create<linalg::IndexOp>(loc, 1);
          auto x0 = rewriter.create<linalg::IndexOp>(loc, 2);
          auto y1 = rewriter.create<arith::SubIOp>(loc, iH, y0);
          auto x1 = rewriter.create<arith::SubIOp>(loc, iW, x0);

          // Determines what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value v, Value x, int64_t pad) -> Value {
            if (pad == 0)
              return v;

            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dx = rewriter.create<arith::SubIOp>(loc, x, padVal);

            Value cmp = rewriter.create<arith::CmpIOp>(
                loc, arith::CmpIPredicate::slt, dx, zero);
            Value offset = rewriter.create<arith::SelectOp>(loc, cmp, dx, zero);
            return rewriter.create<arith::AddIOp>(loc, v, offset)->getResult(0);
          };

          // Compute the vertical component of coverage.
          auto kH0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[0]);
          auto kH1 = padFn(kH0, y0, pad[2]);
          auto kH2 = padFn(kH1, y1, pad[3]);
          auto kHCmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::slt, kH2, one);
          auto kH3 = rewriter.create<arith::SelectOp>(loc, kHCmp, one, kH2);

          // Compute the horizontal component of coverage.
          auto kW0 = rewriter.create<arith::ConstantIndexOp>(loc, kernel[1]);
          auto kW1 = padFn(kW0, x0, pad[4]);
          auto kW2 = padFn(kW1, x1, pad[5]);
          auto kWCmp = rewriter.create<arith::CmpIOp>(
              loc, arith::CmpIPredicate::slt, kW2, one);
          auto kW3 = rewriter.create<arith::SelectOp>(loc, kWCmp, one, kW2);

          // Compute the total number of elements and normalize.
          Value count = rewriter.create<arith::MulIOp>(loc, kH3, kW3);
          auto countI = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(), count);

          // Divide by the number of summed values. For floats this is just a
          // division; for quantized values, input normalization must also be
          // applied.
          Value poolVal = args[0];
          if (accETy.isa<FloatType>()) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, countI);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
          } else {
            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, countI, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Compute the multiplier and shift values for the quantization
            // normalization. Ideally we would compute with more bits, but
            // 32 bits should be enough for this computation. Arguably we
            // should just divide directly instead.
            int64_t numerator = ((1 << 30) + 1);
            int64_t shift = 30;

            Value numeratorVal = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(numerator));
            Value multiplierVal =
                rewriter
                    .create<arith::DivUIOp>(loc, rewriter.getI32Type(),
                                            numeratorVal, countI)
                    .getResult();
            Value shiftVal = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(shift));

            auto scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplierVal,
                        shiftVal, rewriter.getBoolAttr(false))
                    .getResult();

            // If we have quantization information we need to apply the output
            // zeropoint.
            if (op.getQuantizationInfo()) {
              auto quantizationInfo = *op.getQuantizationInfo();
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(),
                                        quantizationInfo.getOutputZp()));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Apply Clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            auto clamp = clampHelper<arith::CmpIOp>(
                loc, scaled, min, max, arith::CmpIPredicate::slt, rewriter);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
            }
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      ConvConverter,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}