//===- TosaDecomposeTransposeConv.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a dilated TransposeConv2D into a regular Conv2D by reversing and
//     reshaping the weights.
// (2) Convert a strided TransposeConv2D into a regular Conv2D by transposing,
//     reversing, and reshaping the weights and the input/output tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

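// Creates a TOSA op with the requested result type and, when the op implements
// InferShapedTypeOpInterface, refines that result type by joining it with the
// inferred shape components. The element type of the requested result type is
// preserved (e.g. rescale may change bit-width and has no TypeAttr for the
// target type). Callers below typically pass an UnrankedTensorType and rely on
// inference to recover a ranked result, for example (operand names are
// illustrative only):
//
//   Value padded = createOpAndInfer<tosa::PadOp>(
//       rewriter, loc, UnrankedTensorType::get(weightETy), weight, paddingVal);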
template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge =
      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto newTy = newKnowledge.getType();
  result.setType(newTy);
  return op;
}

class TransposeConvDilatedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // If striding is all 1 we can modify padding and reverse the kernel along
    // the x/y direction to make it a regular convolution. This is much simpler
    // than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // The dilated kernel covers (weightDim - 1) * dilation + 1 input elements,
    // so pad the input far enough that a unit-stride convolution produces the
    // requested output size.
    int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
    int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
    int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
    int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 - pad[0];
    convPad[2] = kernelWidth - 1 - pad[1];
    convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
    convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);

    // Reverse the kernel along the spatial (height/width) dimensions.
    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation),
          op.quantization_info().getValue());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};

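// Decomposes a strided TransposeConv2D into a unit-stride Conv2D plus
// reshapes. A sketch of the steps performed below:
//   1. Pad the weight so its height/width are multiples of the stride, fold
//      the stride factors into the output channels, and reverse the kernel
//      spatially.
//   2. Pad the input and run a unit-stride Conv2D against the restrided
//      weight, using a zero bias.
//   3. Reshape/transpose the result to move the stride factors back into the
//      spatial dimensions, slice out the requested output, and add the
//      original bias.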
class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // Dilation is handled by TransposeConvDilatedConverter; this pattern only
    // supports unit dilation.
    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
      return failure();

    // If strides are all 1 we don't need this pattern.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so that its height and width are multiples of the stride.
    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
    weightPadding[5] =
        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

    if (op.quantization_info().hasValue()) {
      auto quantInfo = op.quantization_info().getValue();
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          PadOpQuantizationAttr::get(quantInfo.weight_zp(),
                                     rewriter.getContext()));

    } else {
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal);
    }

    weightTy = weight.getType().cast<ShapedType>();
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);

    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims0));

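    // For example, with stride == [2, 2] and a padded weight of shape
    // [OC, KH, KW, IC], the reshape above yields [OC, KH/2, 2, KW/2, 2, IC].
    // The transpose and reshape below move the two stride factors to the
    // front and collapse them together with the output channels, producing a
    // weight of shape [OC * 4, KH/2, KW/2, IC] for the unit-stride Conv2D.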
    // Transpose the factored-out stride to the output channels.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims1));
    ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();

    // Reverse the kernel along the spatial (height/width) dimensions.
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(1));
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(2));

    // We need to pad the input far enough that we can pull all values: by the
    // restrided kernel size minus one on each spatial side.
    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

    if (op.quantization_info().hasValue()) {
      auto quantInfo = op.quantization_info().getValue();
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          PadOpQuantizationAttr::get(quantInfo.input_zp(),
                                     rewriter.getContext()));
    } else {
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal);
    }

    // Use a zero bias for the convolution; the original bias is added back
    // once the result has been restored to its final layout.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
                   op.quantization_info().getValue())
                   .getResult();
    } else {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
                   .getResult();
    }

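    // The convolution result has shape [N, H', W', OC * stride_y * stride_x].
    // For example, with stride == [2, 2], the reshape/transpose/reshape below
    // rearrange it to [N, H', W', 2, 2, OC], then [N, H', 2, W', 2, OC], and
    // finally [N, H' * 2, W' * 2, OC], interleaving the stride factors back
    // into the spatial dimensions before slicing out the requested output.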
    // Factor the resulting width / height back out of the convolution result.
    ShapedType convTy = conv2d.getType().cast<ShapedType>();
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims0));

    // Transpose the factored-out stride next to the spatial dimensions.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims1));

    // Slice out the final result.
    llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
                                            resultTy.getShape().end());
    sliceBegin[1] = pad[0];
    sliceBegin[2] = pad[1];

    auto slice = createOpAndInfer<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     rewriter.getI64ArrayAttr(sliceBegin),
                     rewriter.getI64ArrayAttr(sliceSize))
                     .getResult();

    // Add the original bias to the sliced result.
    auto addBias =
        createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);

    rewriter.replaceOp(op, addBias.getResult());

    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.insert<TransposeConvDilatedConverter>(ctx);
  patterns.insert<TransposeConvStridedConverter>(ctx);
}