//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a sequence of TOSA ops:
// (1) Convert a dilated TransposeConv2D to a Conv2D by reversing the weights
//     and adjusting the padding.
// (2) Convert a strided TransposeConv2D to a Conv2D by padding, reshaping,
//     transposing, and reversing the weights, and by padding/reshaping/
//     transposing/slicing the input and output tensors.
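//
// As an illustrative example (not normative; assuming an NHWC input
// [N, H, W, IC] and a weight of shape [OC, KH, KW, IC]): for stride {2, 2}
// and a 3x3 kernel, the strided path zero-pads the weight to 4x4, reshapes
// it to [OC, 2, 2, 2, 2, IC], transposes it to [2, 2, OC, 2, 2, IC], and
// collapses it to [4 * OC, 2, 2, IC]; a stride-1 Conv2D then produces a
// [N, H', W', 4 * OC] result that is reshaped, transposed, and sliced back
// to the transpose convolution's output shape.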
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

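// Collects the signed integer values of an ArrayAttr of IntegerAttrs into
// `arrayValues`.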
template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

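// Creates a TOSA operation and, when the op implements
// InferShapedTypeOpInterface, refines its result type by joining the
// requested result type with the inferred shape.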
template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge =
      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto newTy = newKnowledge.getType();
  result.setType(newTy);
  return op;
}

class TransposeConvDilatedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // If the stride is 1 in both dimensions we can adjust the padding and
    // reverse the kernel along the spatial axes to turn this into a regular
    // convolution, which is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
    int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
    int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
    int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;

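    // A unit-stride transpose conv is a "full" convolution with the reversed
    // kernel: pad the leading edges by the effective (dilated) kernel size
    // minus one, less the requested out_pad, and pad the trailing edges so
    // that a stride-1 conv produces exactly resultTy's spatial shape.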
    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 - pad[0];
    convPad[2] = kernelWidth - 1 - pad[1];
    convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
    convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation),
          op.quantization_info().getValue());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // This pattern handles strided transpose convolutions; decomposing one
    // that also has non-unit dilation is not supported.
    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
      return failure();

    // If the stride is 1 in both dimensions, the dilated converter above
    // already handles this case.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight's spatial dimensions up to a multiple of the stride.
    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
    weightPadding[5] =
        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
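    // For example, a 3x3 kernel with stride {2, 2} gets one trailing row and
    // one trailing column of zeros, bringing it to 4x4.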
    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

    if (op.quantization_info().hasValue()) {
      auto quantInfo = op.quantization_info().getValue();
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeight_zp()));

    } else {
      weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                             UnrankedTensorType::get(weightETy),
                                             weight, weightPaddingVal);
    }

    weightTy = weight.getType().cast<ShapedType>();
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims0));
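    // The weight is now [OC, KH / sH, sH, KW / sW, sW, IC].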

    // Transpose the factored-out stride dimensions in front of the output
    // channels so they can be collapsed together below.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);
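    // The weight is now [sH, sW, OC, KH / sH, KW / sW, IC].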

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims1));
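    // The weight is now [OC * sH * sW, KH / sH, KW / sW, IC].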
    ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();

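    // Reverse the kernel's spatial dimensions: as in the unit-stride case, a
    // transpose convolution correlates against the flipped kernel.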
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(1));
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(2));

    // Pad the input so the stride-1 convolution can read every contributing
    // value: the restrided kernel size minus one on each spatial edge.
    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

    if (op.quantization_info().hasValue()) {
      auto quantInfo = op.quantization_info().getValue();
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInput_zp()));
    } else {
      input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                            UnrankedTensorType::get(inputETy),
                                            input, inputPaddingVal);
    }

    // Use a zero bias for the convolution; the real bias is broadcast-added
    // back in after the result is reshaped to its final spatial shape.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
                   op.quantization_info().getValue())
                   .getResult();
    } else {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
                   .getResult();
    }

    // Factor the resulting width / height.
    ShapedType convTy = conv2d.getType().cast<ShapedType>();
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims0));
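    // The result is now [N, H', W', sH, sW, OC].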

    // Interleave the factored-out stride with the spatial dimensions.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims1));
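    // The result is now [N, H' * sH, W' * sW, OC].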

    // Slice out the final result, dropping the rows/columns introduced by the
    // requested out_pad.
    llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
                                            resultTy.getShape().end());
    sliceBegin[1] = pad[0];
    sliceBegin[2] = pad[1];

    auto slice = createOpAndInfer<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     rewriter.getI64ArrayAttr(sliceBegin),
                     rewriter.getI64ArrayAttr(sliceSize))
                     .getResult();

    auto addBias =
        createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);

    rewriter.replaceOp(op, addBias.getResult());

    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvDilatedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}