1 //===- TosaDecomposeTransposeConv.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Decompose TOSA TransposeConv operation to a series of TOSA Ops specifically 10 // (1) Convert a Dilated TransposeConv2D to Conv2D including reversing/reshaping 11 // etc.. of the weights (2) Convert a Strided TransposeConv2D to Conv2D 12 // including transposing/reversing/reshaping etc.. 13 // of the weights and input/output tenors and reversing/reshaping etc .. of 14 // the weights 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "mlir/Dialect/Tosa/IR/TosaOps.h" 19 #include "mlir/Dialect/Tosa/Transforms/Passes.h" 20 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" 21 #include "mlir/Pass/Pass.h" 22 23 using namespace mlir; 24 using namespace mlir::tosa; 25 26 namespace { 27 28 template <typename T> 29 static void getValuesFromIntArrayAttribute(ArrayAttr attr, 30 SmallVector<T> &arrayValues) { 31 for (Attribute val : attr.getValue()) { 32 arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue()); 33 } 34 } 35 36 template <typename TosaOp, typename... Args> 37 TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy, 38 Args &&...args) { 39 auto op = rewriter.create<TosaOp>(loc, resultTy, args...); 40 41 InferShapedTypeOpInterface shapeInterface = 42 dyn_cast<InferShapedTypeOpInterface>(op.getOperation()); 43 if (!shapeInterface) 44 return op; 45 46 SmallVector<ShapedTypeComponents> returnedShapes; 47 if (shapeInterface 48 .inferReturnTypeComponents(op.getContext(), op.getLoc(), 49 op->getOperands(), op->getAttrDictionary(), 50 op->getRegions(), returnedShapes) 51 .failed()) 52 return op; 53 54 // We need to use the element type of the existing result type to generate 55 // the new result shaped type. This is because rescale can include a cast to 56 // different bit-width types and does not have a TypeAttr to define the 57 // target type. 58 auto result = op->getResult(0); 59 auto predictedShape = returnedShapes[0]; 60 auto currentKnowledge = 61 mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy); 62 63 // Compute the knowledge based on the inferred type. 64 auto inferredKnowledge = 65 mlir::tosa::ValueKnowledge::getPessimisticValueState(); 66 inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType(); 67 inferredKnowledge.hasRank = predictedShape.hasRank(); 68 if (predictedShape.hasRank()) { 69 for (auto dim : predictedShape.getDims()) { 70 inferredKnowledge.sizes.push_back(dim); 71 } 72 } 73 74 // Compute the new type based on the joined version. 75 auto newKnowledge = 76 mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge); 77 auto newTy = newKnowledge.getType(); 78 result.setType(newTy); 79 return op; 80 } 81 82 class TransposeConvNonStridedConverter 83 : public OpRewritePattern<tosa::TransposeConv2DOp> { 84 public: 85 using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern; 86 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op, 87 PatternRewriter &rewriter) const final { 88 Location loc = op->getLoc(); 89 Value input = op->getOperand(0); 90 Value weight = op->getOperand(1); 91 Value bias = op->getOperand(2); 92 93 ShapedType inputTy = input.getType().cast<ShapedType>(); 94 ShapedType weightTy = weight.getType().cast<ShapedType>(); 95 ShapedType biasTy = bias.getType().cast<ShapedType>(); 96 ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>(); 97 98 llvm::SmallVector<int64_t> pad; 99 llvm::SmallVector<int64_t> stride; 100 101 getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad); 102 getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride); 103 104 // If striding is all 1 we can modify padding and reverse the kernel along 105 // the x/y direction to make it a regular convolution. This is much simpler 106 // then handling striding.... 107 if (llvm::any_of(stride, [](int64_t v) { return v != 1; })) 108 return failure(); 109 110 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || 111 !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) 112 return failure(); 113 114 int64_t kernelHeight = weightTy.getDimSize(1); 115 int64_t kernelWidth = weightTy.getDimSize(2); 116 117 llvm::SmallVector<int64_t> convPad(4, 0); 118 convPad[0] = kernelHeight - 1 - pad[0]; 119 convPad[1] = kernelHeight - 1 - pad[1]; 120 convPad[2] = kernelWidth - 1 - pad[2]; 121 convPad[3] = kernelWidth - 1 - pad[3]; 122 123 auto reverse1 = rewriter.create<tosa::ReverseOp>( 124 loc, weightTy, weight, rewriter.getI64IntegerAttr(1)); 125 auto reverse2 = rewriter.create<tosa::ReverseOp>( 126 loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2)); 127 128 Value conv2d; 129 if (op.quantization_info()) { 130 conv2d = rewriter.create<tosa::Conv2DOp>( 131 loc, resultTy, input, reverse2, bias, 132 rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride), 133 rewriter.getI64ArrayAttr({1, 1}), *op.quantization_info()); 134 } else { 135 conv2d = rewriter.create<tosa::Conv2DOp>( 136 loc, resultTy, input, reverse2, bias, 137 rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride), 138 rewriter.getI64ArrayAttr({1, 1})); 139 } 140 141 rewriter.replaceOp(op, conv2d); 142 return success(); 143 } 144 }; 145 146 class TransposeConvStridedConverter 147 : public OpRewritePattern<tosa::TransposeConv2DOp> { 148 public: 149 using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern; 150 LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op, 151 PatternRewriter &rewriter) const final { 152 Location loc = op->getLoc(); 153 Value input = op->getOperand(0); 154 Value weight = op->getOperand(1); 155 Value bias = op->getOperand(2); 156 157 ShapedType inputTy = input.getType().cast<ShapedType>(); 158 ShapedType weightTy = weight.getType().cast<ShapedType>(); 159 ShapedType biasTy = bias.getType().cast<ShapedType>(); 160 ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>(); 161 162 Type inputETy = inputTy.getElementType(); 163 Type weightETy = weightTy.getElementType(); 164 Type biasETy = biasTy.getElementType(); 165 Type resultETy = resultTy.getElementType(); 166 167 llvm::SmallVector<int64_t> pad; 168 llvm::SmallVector<int64_t> stride; 169 170 getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad); 171 getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride); 172 173 // If striding is all 1 we can modify padding and reverse the kernel along 174 // the x/y direction to make it a regular convolution. This is much simpler 175 // then handling striding.... 176 177 // If strides are all 1 we dont need to use this one. 178 if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) 179 return failure(); 180 181 if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || 182 !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) 183 return failure(); 184 185 int64_t batch = inputTy.getDimSize(0); 186 187 int64_t outputChannels = weightTy.getDimSize(0); 188 int64_t weightHeight = weightTy.getDimSize(1); 189 int64_t weightWidth = weightTy.getDimSize(2); 190 int64_t inputChannels = weightTy.getDimSize(3); 191 192 // Pad the weight so that it is modulo of the striding. 193 llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0}; 194 weightPadding[3] = 195 weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0; 196 weightPadding[5] = 197 weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0; 198 DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get( 199 RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding); 200 Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>( 201 rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr); 202 203 if (op.quantization_info().hasValue()) { 204 auto quantInfo = op.quantization_info().getValue(); 205 weight = createOpAndInfer<tosa::PadOp>( 206 rewriter, loc, UnrankedTensorType::get(weightETy), weight, 207 weightPaddingVal, nullptr, 208 rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp())); 209 210 } else { 211 weight = createOpAndInfer<tosa::PadOp>(rewriter, loc, 212 UnrankedTensorType::get(weightETy), 213 weight, weightPaddingVal); 214 } 215 216 weightTy = weight.getType().cast<ShapedType>(); 217 weightHeight = weightTy.getDimSize(1); 218 weightWidth = weightTy.getDimSize(2); 219 220 // Split out the width / height by the stride dimensions. 221 llvm::SmallVector<int64_t, 6> weightReshapeDims0 = { 222 outputChannels, weightHeight / stride[0], 223 stride[0], weightWidth / stride[1], 224 stride[1], inputChannels}; 225 weight = createOpAndInfer<tosa::ReshapeOp>( 226 rewriter, loc, UnrankedTensorType::get(weightETy), weight, 227 rewriter.getI64ArrayAttr(weightReshapeDims0)); 228 229 // Transpose the factored-out stride to the output channels. 230 Value transposeWeightVal = rewriter.create<tosa::ConstOp>( 231 loc, RankedTensorType::get({6}, rewriter.getI32Type()), 232 rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5})); 233 234 weight = createOpAndInfer<tosa::TransposeOp>( 235 rewriter, loc, UnrankedTensorType::get(weightETy), weight, 236 transposeWeightVal); 237 238 // Collapse the strides and output channels into a single dimension. 239 llvm::SmallVector<int64_t, 6> weightReshapeDims1 = { 240 outputChannels * stride[0] * stride[1], weightHeight / stride[0], 241 weightWidth / stride[1], inputChannels}; 242 weight = createOpAndInfer<tosa::ReshapeOp>( 243 rewriter, loc, UnrankedTensorType::get(weightETy), weight, 244 rewriter.getI64ArrayAttr(weightReshapeDims1)); 245 ShapedType restridedWeightTy = weight.getType().cast<ShapedType>(); 246 247 weight = createOpAndInfer<tosa::ReverseOp>( 248 rewriter, loc, UnrankedTensorType::get(weightETy), weight, 249 rewriter.getI64IntegerAttr(1)); 250 weight = createOpAndInfer<tosa::ReverseOp>( 251 rewriter, loc, UnrankedTensorType::get(weightETy), weight, 252 rewriter.getI64IntegerAttr(2)); 253 254 // We need to pad the input far enough that we can pull all values. 255 llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0}; 256 inputPadding[2] += restridedWeightTy.getDimSize(1) - 1; 257 inputPadding[3] += restridedWeightTy.getDimSize(1) - 1; 258 inputPadding[4] += restridedWeightTy.getDimSize(2) - 1; 259 inputPadding[5] += restridedWeightTy.getDimSize(2) - 1; 260 261 DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get( 262 RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding); 263 264 Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>( 265 rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr); 266 267 if (op.quantization_info().hasValue()) { 268 auto quantInfo = op.quantization_info().getValue(); 269 input = createOpAndInfer<tosa::PadOp>( 270 rewriter, loc, UnrankedTensorType::get(inputETy), input, 271 inputPaddingVal, nullptr, 272 rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp())); 273 } else { 274 input = createOpAndInfer<tosa::PadOp>(rewriter, loc, 275 UnrankedTensorType::get(inputETy), 276 input, inputPaddingVal); 277 } 278 279 // We use a zero bias as we need to broadcast the bias. 280 auto zeroBias = rewriter.create<tosa::ConstOp>( 281 loc, 282 RankedTensorType::get({outputChannels * stride[0] * stride[1]}, 283 biasETy), 284 DenseElementsAttr::get( 285 RankedTensorType::get({outputChannels * stride[0] * stride[1]}, 286 biasETy), 287 rewriter.getZeroAttr(biasETy))); 288 289 // Perform the convolution using the zero bias. 290 Value conv2d; 291 if (op.quantization_info()) { 292 conv2d = createOpAndInfer<tosa::Conv2DOp>( 293 rewriter, loc, UnrankedTensorType::get(resultETy), input, 294 weight, zeroBias, 295 /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}), 296 /*stride=*/rewriter.getI64ArrayAttr({1, 1}), 297 /*dilation=*/rewriter.getI64ArrayAttr({1, 1}), 298 *op.quantization_info()) 299 .getResult(); 300 } else { 301 conv2d = createOpAndInfer<tosa::Conv2DOp>( 302 rewriter, loc, UnrankedTensorType::get(resultETy), input, 303 weight, zeroBias, 304 /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}), 305 /*stride=*/rewriter.getI64ArrayAttr({1, 1}), 306 /*dilation=*/rewriter.getI64ArrayAttr({1, 1})) 307 .getResult(); 308 } 309 310 // Factor the resulting width / height. 311 ShapedType convTy = conv2d.getType().cast<ShapedType>(); 312 Type convETy = convTy.getElementType(); 313 314 int64_t convHeight = convTy.getDimSize(1); 315 int64_t convWidth = convTy.getDimSize(2); 316 317 // Factor striding out of the convolution result. 318 llvm::SmallVector<int64_t, 6> convReshapeDims0 = { 319 batch, convHeight, convWidth, stride[0], stride[1], outputChannels}; 320 conv2d = createOpAndInfer<tosa::ReshapeOp>( 321 rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, 322 rewriter.getI64ArrayAttr(convReshapeDims0)); 323 324 // Transpose the factored-out stride to the output channels. 325 Value transposeConvVal = rewriter.create<tosa::ConstOp>( 326 loc, RankedTensorType::get({6}, rewriter.getI32Type()), 327 rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5})); 328 329 conv2d = createOpAndInfer<tosa::TransposeOp>( 330 rewriter, loc, UnrankedTensorType::get(convETy), conv2d, 331 transposeConvVal); 332 333 // Fuse striding behavior back into width / height. 334 llvm::SmallVector<int64_t, 6> convReshapeDims1 = { 335 batch, convHeight * stride[0], convWidth * stride[1], outputChannels}; 336 conv2d = createOpAndInfer<tosa::ReshapeOp>( 337 rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, 338 rewriter.getI64ArrayAttr(convReshapeDims1)); 339 340 // Slice out the final result. 341 llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0}; 342 llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(), 343 resultTy.getShape().begin()); 344 sliceBegin[1] = pad[0]; 345 sliceBegin[2] = pad[2]; 346 347 auto slice = createOpAndInfer<tosa::SliceOp>( 348 rewriter, loc, UnrankedTensorType::get(resultETy), conv2d, 349 rewriter.getI64ArrayAttr(sliceBegin), 350 rewriter.getI64ArrayAttr(resultTy.getShape())) 351 .getResult(); 352 353 auto addBias = 354 createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias); 355 356 rewriter.replaceOp(op, addBias.getResult()); 357 358 return success(); 359 } 360 }; 361 362 } // namespace 363 364 void mlir::tosa::populateTosaDecomposeTransposeConv( 365 MLIRContext *ctx, RewritePatternSet &patterns) { 366 patterns.add<TransposeConvNonStridedConverter>(ctx); 367 patterns.add<TransposeConvStridedConverter>(ctx); 368 } 369