//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D to a Conv2D by adjusting the
//     padding and reversing the kernel along the spatial dimensions.
// (2) Convert a strided TransposeConv2D to a Conv2D by additionally padding,
//     reshaping, and transposing the weights and the input/output tensors.
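//
// For example, with a stride of [2, 2] and a 3x3 kernel, the strided
// conversion pads the kernel to 4x4, factors it into four 2x2 sub-kernels
// run as a single stride-1 Conv2D with four times the output channels, and
// then interleaves and slices the results back into the strided output.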
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

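// Creates a TOSA operation and performs shape inference on it when possible,
// refining the result type with the inferred shape.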
template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge =
      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto newTy = newKnowledge.getType();
  result.setType(newTy);
  return op;
}

class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;

    getValuesFromIntArrayAttribute(op.getOutPad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.getStride().cast<ArrayAttr>(), stride);

    // If the striding is all 1 we can modify the padding and reverse the
    // kernel along the x/y direction to make it a regular convolution. This
    // is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 - pad[0];
    convPad[1] = kernelHeight - 1 - pad[1];
    convPad[2] = kernelWidth - 1 - pad[2];
    convPad[3] = kernelWidth - 1 - pad[3];
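    // E.g. for a 3x3 kernel with an out_pad of zero this yields a padding of
    // {2, 2, 2, 2}, i.e. a "full" convolution over the reversed kernel.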

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr({1, 1}));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;

    getValuesFromIntArrayAttribute(op.getOutPad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.getStride().cast<ArrayAttr>(), stride);

    // If all strides are 1 the non-strided converter above applies; this
    // pattern handles only the genuinely strided case.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so that each spatial dimension is a multiple of the
    // corresponding stride.
    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
    weightPadding[5] =
        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
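    // E.g. a 3x3 kernel with stride [2, 2] is padded (at the bottom/right) to
    // 4x4 so it splits evenly into 2x2 sub-kernels below.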
    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));

    } else {
      weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                             UnrankedTensorType::get(weightETy),
                                             weight, weightPaddingVal);
    }

    weightTy = weight.getType().cast<ShapedType>();
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

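    // The reshape/transpose/reshape sequence below restrides the weights,
    // e.g. for stride [s0, s1]:
    //   [OC, KH, KW, IC]
    //     -> [OC, KH/s0, s0, KW/s1, s1, IC]
    //     -> [s0, s1, OC, KH/s0, KW/s1, IC]   (transpose {2, 4, 0, 1, 3, 5})
    //     -> [OC * s0 * s1, KH/s0, KW/s1, IC]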
    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims0));

    // Transpose the factored-out strides in front of the output channels.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims1));
    ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();

    // Reverse the sub-kernels along the spatial dimensions.
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(1));
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(2));

    // Pad the input with (sub-kernel size - 1) on each spatial edge so the
    // stride-1 convolution can read every value the transpose convolution
    // would produce.
    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
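    // E.g. for the 2x2 sub-kernels in the example above, each spatial edge of
    // the input is padded by 1 (a "full" convolution).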

    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
    } else {
      input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                            UnrankedTensorType::get(inputETy),
                                            input, inputPaddingVal);
    }

    // Use a zero bias in the convolution: its output has
    // outputChannels * stride[0] * stride[1] channels, so the real bias can
    // only be broadcast-added after the final reshape.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
                   *op.getQuantizationInfo())
                   .getResult();
    } else {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
                   .getResult();
    }

    // Factor the resulting width / height.
    ShapedType convTy = conv2d.getType().cast<ShapedType>();
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

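    // The reshape/transpose/reshape below interleaves the per-stride results
    // back into the spatial dimensions, e.g. for stride [s0, s1]:
    //   [N, H, W, OC * s0 * s1]
    //     -> [N, H, W, s0, s1, OC]
    //     -> [N, H, s0, W, s1, OC]   (transpose {0, 1, 3, 2, 4, 5})
    //     -> [N, H * s0, W * s1, OC]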
    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims0));

    // Interleave the factored-out strides with the spatial dimensions.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims1));

    // Slice out the final result.
    llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
                                            resultTy.getShape().end());
    sliceBegin[1] = pad[0];
    sliceBegin[2] = pad[2];

    auto slice = createOpAndInfer<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     rewriter.getI64ArrayAttr(sliceBegin),
                     rewriter.getI64ArrayAttr(sliceSize))
                     .getResult();

    auto addBias =
        createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);

    rewriter.replaceOp(op, addBias.getResult());

    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}