//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a non-strided TransposeConv2D to a Conv2D by reversing the
//     weights and adjusting the padding.
// (2) Convert a strided TransposeConv2D to a Conv2D by transposing,
//     reversing, and reshaping the weights and the input/output tensors.
//     For example, a stride-[2, 2] transpose convolution becomes a
//     unit-stride convolution producing 4x the output channels, followed by
//     reshapes and a transpose that interleave those channels back into the
//     spatial dimensions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

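// Create a TOSA op and, when it implements InferShapedTypeOpInterface,
// refine its result type by joining the inferred shape against the
// requested result type.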
template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge =
      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto newTy = newKnowledge.getType();
  result.setType(newTy);
  return op;
}

class TransposeConvNonStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;

    getValuesFromIntArrayAttribute(op.getOutPad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.getStride().cast<ArrayAttr>(), stride);

    // If the stride is 1 in both dimensions we can adjust the padding and
    // reverse the kernel along the x/y directions to turn this into a
    // regular convolution. This is much simpler than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t kernelHeight = weightTy.getDimSize(1);
    int64_t kernelWidth = weightTy.getDimSize(2);

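    // A unit-stride transpose convolution is a regular convolution over an
    // input padded so every output position sees a full kernel window: each
    // convolution pad becomes (kernelSize - 1 - outPad).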
    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 - pad[0];
    convPad[1] = kernelHeight - 1 - pad[1];
    convPad[2] = kernelWidth - 1 - pad[2];
    convPad[3] = kernelWidth - 1 - pad[3];

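    // Reverse the kernel along both spatial axes; a transpose convolution
    // applies the kernel spatially flipped relative to a regular convolution.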
    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr({1, 1}));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;

    getValuesFromIntArrayAttribute(op.getOutPad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.getStride().cast<ArrayAttr>(), stride);

    // If the stride is 1 in both dimensions this pattern does not apply; the
    // non-strided converter above handles that case.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so that each spatial dimension is a multiple of the
    // corresponding stride, e.g. a 5x5 kernel with stride [2, 2] is padded
    // to 6x6.
    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
    weightPadding[5] =
        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));

    } else {
      weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                             UnrankedTensorType::get(weightETy),
                                             weight, weightPaddingVal);
    }

    weightTy = weight.getType().cast<ShapedType>();
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims0));
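    // e.g. with stride [2, 2], an [OC, KH, KW, IC] kernel becomes
    // [OC, KH / 2, 2, KW / 2, 2, IC].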

    // Transpose the factored-out stride to the output channels.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims1));
    ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();
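    // The weight is now [OC * stride_y * stride_x, KH / stride_y,
    // KW / stride_x, IC]; each group of output channels computes one stride
    // phase of the final result.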

    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(1));
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(2));

    // Pad the input far enough that the convolution can read every value it
    // needs: (kernel - 1) on each side of both spatial dimensions.
    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

    if (op.getQuantizationInfo().has_value()) {
      auto quantInfo = op.getQuantizationInfo().value();
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
    } else {
      input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                            UnrankedTensorType::get(inputETy),
                                            input, inputPaddingVal);
    }

    // Use a zero bias for the convolution; the real bias would need
    // broadcasting across the factored-out stride dimensions, so it is added
    // after the final reshape instead.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.getQuantizationInfo()) {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
                   *op.getQuantizationInfo())
                   .getResult();
    } else {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
                   .getResult();
    }

    // Read back the dimensions of the intermediate convolution result.
    ShapedType convTy = conv2d.getType().cast<ShapedType>();
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims0));
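    // e.g. with stride [2, 2], an [N, H, W, 4 * OC] convolution result
    // becomes [N, H, W, 2, 2, OC].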

    // Interleave the factored-out stride phases with the spatial dimensions:
    // [N, H, W, sy, sx, OC] -> [N, H, sy, W, sx, OC].
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims1));
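    // The result is now [N, H * sy, W * sx, OC], which may be larger than the
    // requested output; the slice below trims it to the final shape.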

    // Slice out the final result.
    llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
                                            resultTy.getShape().end());
    sliceBegin[1] = pad[0];
    sliceBegin[2] = pad[2];

    auto slice = createOpAndInfer<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     rewriter.getI64ArrayAttr(sliceBegin),
                     rewriter.getI64ArrayAttr(resultTy.getShape()))
                     .getResult();

    auto addBias =
        createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);

    rewriter.replaceOp(op, addBias.getResult());

    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvNonStridedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}