//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a dilated TransposeConv2D to a Conv2D by reversing the weights
//     along the spatial dimensions and adjusting the padding.
// (2) Convert a strided TransposeConv2D to a Conv2D by transposing, reversing,
//     and reshaping the weights and reshaping/slicing the input/output
//     tensors.
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
20 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
21 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
22 #include "mlir/Pass/Pass.h"
23 
24 using namespace mlir;
25 using namespace mlir::tosa;
26 
27 namespace {
28 
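// Collect the signed integer values of an integer ArrayAttr into
// `arrayValues`.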
29 template <typename T>
30 static void getValuesFromIntArrayAttribute(ArrayAttr attr,
31                                            SmallVector<T> &arrayValues) {
32   for (Attribute val : attr.getValue()) {
33     arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
34   }
35 }
36 
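// Create a TOSA op and, if it implements InferShapedTypeOpInterface, run
// shape inference and refine the result type with the inferred shape. The
// element type of `resultTy` is preserved.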
37 template <typename TosaOp, typename... Args>
38 TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
39                         Args &&...args) {
40   auto op = rewriter.create<TosaOp>(loc, resultTy, args...);
41 
42   InferShapedTypeOpInterface shapeInterface =
43       dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
44   if (!shapeInterface)
45     return op;
46 
47   SmallVector<ShapedTypeComponents> returnedShapes;
48   if (shapeInterface
49           .inferReturnTypeComponents(op.getContext(), op.getLoc(),
50                                      op->getOperands(), op->getAttrDictionary(),
51                                      op->getRegions(), returnedShapes)
52           .failed())
53     return op;
54 
55   // We need to use the element type of the existing result type to generate
56   // the new result shaped type. This is because rescale can include a cast to
57   // different bit-width types and does not have a TypeAttr to define the
58   // target type.
59   auto result = op->getResult(0);
60   auto predictedShape = returnedShapes[0];
61   auto currentKnowledge =
62       mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);
63 
64   // Compute the knowledge based on the inferred type.
65   auto inferredKnowledge =
66       mlir::tosa::ValueKnowledge::getPessimisticValueState();
67   inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
68   inferredKnowledge.hasRank = predictedShape.hasRank();
69   if (predictedShape.hasRank()) {
70     for (auto dim : predictedShape.getDims()) {
71       inferredKnowledge.sizes.push_back(dim);
72     }
73   }
74 
75   // Compute the new type based on the joined version.
76   auto newKnowledge =
77       mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
78   auto newTy = newKnowledge.getType();
79   result.setType(newTy);
80   return op;
81 }
82 
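// Convert a TransposeConv2D with unit strides into a regular Conv2D by
// flipping the kernel along both spatial dimensions and adjusting the
// padding.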
83 class TransposeConvDilatedConverter
84     : public OpRewritePattern<tosa::TransposeConv2DOp> {
85 public:
86   using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
87   LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
88                                 PatternRewriter &rewriter) const final {
89     Location loc = op->getLoc();
90     Value input = op->getOperand(0);
91     Value weight = op->getOperand(1);
92     Value bias = op->getOperand(2);
93 
94     ShapedType inputTy = input.getType().cast<ShapedType>();
95     ShapedType weightTy = weight.getType().cast<ShapedType>();
96     ShapedType biasTy = bias.getType().cast<ShapedType>();
97     ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
98 
99     llvm::SmallVector<int64_t> pad;
100     llvm::SmallVector<int64_t> stride;
101     llvm::SmallVector<int64_t> dilation;
102 
103     getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
104     getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
105     getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
106 
    // If all strides are 1 we can adjust the padding and reverse the kernel
    // along the spatial dimensions to turn this into a regular convolution.
    // This is much simpler than handling striding.
110     if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
111       return failure();
112 
113     if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
114         !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
115       return failure();
116 
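    // The effective (dilated) kernel spans (k - 1) * d + 1 input samples, and
    // producing the full transpose-convolution output requires an input of
    // outputSize + effectiveKernelSize - 1 samples after padding.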
117     int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
118     int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
119     int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
120     int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
121 
122     llvm::SmallVector<int64_t> convPad(4, 0);
123     convPad[0] = kernelHeight - 1 - pad[0];
124     convPad[2] = kernelWidth - 1 - pad[1];
125     convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
126     convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
127 
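    // A unit-stride transpose convolution is a regular convolution with the
    // kernel flipped along both spatial dimensions.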
128     auto reverse1 = rewriter.create<tosa::ReverseOp>(
129         loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
130     auto reverse2 = rewriter.create<tosa::ReverseOp>(
131         loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
132 
133     Value conv2d;
134     if (op.quantization_info().hasValue()) {
135       conv2d = rewriter.create<tosa::Conv2DOp>(
136           loc, resultTy, input, reverse2, bias,
137           rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
138           rewriter.getI64ArrayAttr(dilation),
139           op.quantization_info().getValue());
140     } else {
141       conv2d = rewriter.create<tosa::Conv2DOp>(
142           loc, resultTy, input, reverse2, bias,
143           rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
144           rewriter.getI64ArrayAttr(dilation));
145     }
146 
147     rewriter.replaceOp(op, conv2d);
148     return success();
149   }
150 };
151 
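// Convert a strided TransposeConv2D into a unit-stride Conv2D by folding the
// stride into the output channels of the weights. A rough sketch of the
// weight transformation for stride [sH, sW] (shape names are illustrative):
//   [OC, KH, KW, IC]
//     -> reshape   [OC, KH/sH, sH, KW/sW, sW, IC]
//     -> transpose [sH, sW, OC, KH/sH, KW/sW, IC]
//     -> reshape   [OC * sH * sW, KH/sH, KW/sW, IC]
// The unit-stride Conv2D then produces sH * sW output phases per channel,
// which are reshaped and transposed back into the spatial dimensions and
// finally sliced to the expected output shape.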
152 class TransposeConvStridedConverter
153     : public OpRewritePattern<tosa::TransposeConv2DOp> {
154 public:
155   using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
156   LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
157                                 PatternRewriter &rewriter) const final {
158     Location loc = op->getLoc();
159     Value input = op->getOperand(0);
160     Value weight = op->getOperand(1);
161     Value bias = op->getOperand(2);
162 
163     ShapedType inputTy = input.getType().cast<ShapedType>();
164     ShapedType weightTy = weight.getType().cast<ShapedType>();
165     ShapedType biasTy = bias.getType().cast<ShapedType>();
166     ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
167 
168     Type inputETy = inputTy.getElementType();
169     Type weightETy = weightTy.getElementType();
170     Type biasETy = biasTy.getElementType();
171     Type resultETy = resultTy.getElementType();
172 
173     llvm::SmallVector<int64_t> pad;
174     llvm::SmallVector<int64_t> stride;
175     llvm::SmallVector<int64_t> dilation;
176 
177     getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
178     getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
179     getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
180 
    // This decomposition does not handle dilation; bail out on any non-unit
    // dilation.
    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
      return failure();

    // If all strides are 1, the dilated converter above already handles the
    // op.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return failure();
190 
191     if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
192         !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
193       return failure();
194 
195     int64_t batch = inputTy.getDimSize(0);
196 
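    // TOSA transpose_conv2d weights are laid out as [OC, KH, KW, IC].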
197     int64_t outputChannels = weightTy.getDimSize(0);
198     int64_t weightHeight = weightTy.getDimSize(1);
199     int64_t weightWidth = weightTy.getDimSize(2);
200     int64_t inputChannels = weightTy.getDimSize(3);
201 
    // Pad the weights so the spatial dimensions are evenly divisible by the
    // stride. The padding operand is a {4, 2} tensor of [before, after]
    // pairs, so indices 3 and 5 pad the trailing edge of height and width.
203     llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
204     weightPadding[3] =
205         weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
206     weightPadding[5] =
207         weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
208     DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
209         RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
210     Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
211         rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
212 
213     if (op.quantization_info().hasValue()) {
214       auto quantInfo = op.quantization_info().getValue();
215       weight = createOpAndInfer<tosa::PadOp>(
216           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
217           weightPaddingVal, nullptr,
218           PadOpQuantizationAttr::get(quantInfo.weight_zp(),
219                                      rewriter.getContext()));
220 
221     } else {
222       weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
223                                              UnrankedTensorType::get(weightETy),
224                                              weight, weightPaddingVal);
225     }
226 
227     weightTy = weight.getType().cast<ShapedType>();
228     weightHeight = weightTy.getDimSize(1);
229     weightWidth = weightTy.getDimSize(2);
230 
231     // Split out the width / height by the stride dimensions.
232     llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
233         outputChannels, weightHeight / stride[0],
234         stride[0],      weightWidth / stride[1],
235         stride[1],      inputChannels};
236     weight = createOpAndInfer<tosa::ReshapeOp>(
237         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
238         rewriter.getI64ArrayAttr(weightReshapeDims0));
239 
    // Transpose the factored-out stride dimensions to the front:
    // [OC, KH/sH, sH, KW/sW, sW, IC] -> [sH, sW, OC, KH/sH, KW/sW, IC].
241     Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
242         loc, RankedTensorType::get({6}, rewriter.getI32Type()),
243         rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
244 
245     weight = createOpAndInfer<tosa::TransposeOp>(
246         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
247         transposeWeightVal);
248 
249     // Collapse the strides and output channels into a single dimension.
250     llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
251         outputChannels * stride[0] * stride[1], weightHeight / stride[0],
252         weightWidth / stride[1], inputChannels};
253     weight = createOpAndInfer<tosa::ReshapeOp>(
254         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
255         rewriter.getI64ArrayAttr(weightReshapeDims1));
256     ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
257 
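    // Flip the restrided kernel along both spatial dimensions, as in the
    // unit-stride case.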
258     weight = createOpAndInfer<tosa::ReverseOp>(
259         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
260         rewriter.getI64IntegerAttr(1));
261     weight = createOpAndInfer<tosa::ReverseOp>(
262         rewriter, loc, UnrankedTensorType::get(weightETy), weight,
263         rewriter.getI64IntegerAttr(2));
264 
    // Pad the input by (kernel - 1) on every spatial edge so the unit-stride
    // convolution can produce all required output samples.
266     llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
267     inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
268     inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
269     inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
270     inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;
271 
272     DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
273         RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
274 
275     Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
276         rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
277 
278     if (op.quantization_info().hasValue()) {
279       auto quantInfo = op.quantization_info().getValue();
280       input = createOpAndInfer<tosa::PadOp>(
281           rewriter, loc, UnrankedTensorType::get(inputETy), input,
282           inputPaddingVal, nullptr,
283           PadOpQuantizationAttr::get(quantInfo.input_zp(),
284                                      rewriter.getContext()));
285     } else {
286       input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
287                                             UnrankedTensorType::get(inputETy),
288                                             input, inputPaddingVal);
289     }
290 
    // Use a zero bias here; the real bias must be broadcast across the
    // expanded channel dimension, so it is added after the final reshape.
292     auto zeroBias = rewriter.create<tosa::ConstOp>(
293         loc,
294         RankedTensorType::get({outputChannels * stride[0] * stride[1]},
295                               biasETy),
296         DenseElementsAttr::get(
297             RankedTensorType::get({outputChannels * stride[0] * stride[1]},
298                                   biasETy),
299             rewriter.getZeroAttr(biasETy)));
300 
301     // Perform the convolution using the zero bias.
302     Value conv2d;
303     if (op.quantization_info().hasValue()) {
304       conv2d = createOpAndInfer<tosa::Conv2DOp>(
305                    rewriter, loc, UnrankedTensorType::get(resultETy), input,
306                    weight, zeroBias,
307                    /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
308                    /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
309                    /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
310                    op.quantization_info().getValue())
311                    .getResult();
312     } else {
313       conv2d = createOpAndInfer<tosa::Conv2DOp>(
314                    rewriter, loc, UnrankedTensorType::get(resultETy), input,
315                    weight, zeroBias,
316                    /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
317                    /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
318                    /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
319                    .getResult();
320     }
321 
    // Read back the shape of the intermediate convolution result.
323     ShapedType convTy = conv2d.getType().cast<ShapedType>();
324     Type convETy = convTy.getElementType();
325 
326     int64_t convHeight = convTy.getDimSize(1);
327     int64_t convWidth = convTy.getDimSize(2);
328 
    // Split the stride phases out of the channel dimension:
    // [N, H, W, OC * sH * sW] -> [N, H, W, sH, sW, OC].
330     llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
331         batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
332     conv2d = createOpAndInfer<tosa::ReshapeOp>(
333         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
334         rewriter.getI64ArrayAttr(convReshapeDims0));
335 
    // Interleave the stride phases with the spatial dimensions:
    // [N, H, W, sH, sW, OC] -> [N, H, sH, W, sW, OC].
337     Value transposeConvVal = rewriter.create<tosa::ConstOp>(
338         loc, RankedTensorType::get({6}, rewriter.getI32Type()),
339         rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
340 
341     conv2d = createOpAndInfer<tosa::TransposeOp>(
342         rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
343         transposeConvVal);
344 
345     // Fuse striding behavior back into width / height.
346     llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
347         batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
348     conv2d = createOpAndInfer<tosa::ReshapeOp>(
349         rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
350         rewriter.getI64ArrayAttr(convReshapeDims1));
351 
352     // Slice out the final result.
353     llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
354     llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
355                                             resultTy.getShape().begin());
356     sliceBegin[1] = pad[0];
357     sliceBegin[2] = pad[1];
358 
359     auto slice = createOpAndInfer<tosa::SliceOp>(
360                      rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
361                      rewriter.getI64ArrayAttr(sliceBegin),
362                      rewriter.getI64ArrayAttr(resultTy.getShape()))
363                      .getResult();
364 
365     auto addBias =
366         createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);
367 
368     rewriter.replaceOp(op, addBias.getResult());
369 
370     return success();
371   }
372 };
373 
374 } // namespace
375 
376 void mlir::tosa::populateTosaDecomposeTransposeConv(
377     MLIRContext *ctx, RewritePatternSet &patterns) {
378   patterns.insert<TransposeConvDilatedConverter>(ctx);
379   patterns.insert<TransposeConvStridedConverter>(ctx);
380 }
381