//===- TosaDecomposeTransposeConv.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Decompose the TOSA TransposeConv2D operation into a series of TOSA ops:
// (1) Convert a dilated TransposeConv2D to a regular Conv2D by reversing and
//     reshaping the weights.
// (2) Convert a strided TransposeConv2D to a regular Conv2D by transposing,
//     reversing, and reshaping the weights and the input/output tensors.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace mlir::tosa;

namespace {

template <typename T>
static void getValuesFromIntArrayAttribute(ArrayAttr attr,
                                           SmallVector<T> &arrayValues) {
  for (Attribute val : attr.getValue()) {
    arrayValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

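// Creates a TOSA operation and, when the op implements
// InferShapedTypeOpInterface, runs shape inference to refine the result type.
// This keeps shapes propagating through the ops created by the decomposition
// below.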
template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto op = rewriter.create<TosaOp>(loc, resultTy, args...);

  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = resultTy.cast<ShapedType>().getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Compute the new type based on the joined version.
  auto newKnowledge =
      mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  auto newTy = newKnowledge.getType();
  result.setType(newTy);
  return op;
}

// Converts a tosa.transpose_conv2d with unit stride into a regular
// tosa.conv2d by recomputing the padding and reversing the kernel along the
// spatial axes.
class TransposeConvDilatedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // If striding is all 1 we can modify padding and reverse the kernel along
    // the x/y direction to make it a regular convolution. This is much simpler
    // than handling striding.
    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    // Compute the effective (dilated) kernel size and the input extent needed
    // to produce the full output, then derive the padding of the equivalent
    // regular convolution.
    int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
    int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
    int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
    int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;

    llvm::SmallVector<int64_t> convPad(4, 0);
    convPad[0] = kernelHeight - 1 - pad[0];
    convPad[2] = kernelWidth - 1 - pad[1];
    convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
    convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);

    auto reverse1 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
    auto reverse2 = rewriter.create<tosa::ReverseOp>(
        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));

    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation),
          op.quantization_info().getValue());
    } else {
      conv2d = rewriter.create<tosa::Conv2DOp>(
          loc, resultTy, input, reverse2, bias,
          rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
          rewriter.getI64ArrayAttr(dilation));
    }

    rewriter.replaceOp(op, conv2d);
    return success();
  }
};
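
// Converts a strided (non-unit stride, unit dilation) tosa.transpose_conv2d
// into a unit-stride tosa.conv2d: the weights are padded, reshaped,
// transposed, and reversed so that every stride offset becomes its own group
// of output channels, and the convolution result is reshaped, transposed, and
// sliced back into the strided output before the bias is added.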
class TransposeConvStridedConverter
    : public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
  using OpRewritePattern<tosa::TransposeConv2DOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(tosa::TransposeConv2DOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = input.getType().cast<ShapedType>();
    ShapedType weightTy = weight.getType().cast<ShapedType>();
    ShapedType biasTy = bias.getType().cast<ShapedType>();
    ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();

    Type inputETy = inputTy.getElementType();
    Type weightETy = weightTy.getElementType();
    Type biasETy = biasTy.getElementType();
    Type resultETy = resultTy.getElementType();

    llvm::SmallVector<int64_t> pad;
    llvm::SmallVector<int64_t> stride;
    llvm::SmallVector<int64_t> dilation;

    getValuesFromIntArrayAttribute(op.out_pad().cast<ArrayAttr>(), pad);
    getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
    getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);

    // This pattern only handles transpose convolutions with unit dilation.
    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
      return failure();

    // If the strides are all 1 we don't need this pattern; the dilated
    // converter above handles that case.
    if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
      return failure();

    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
      return failure();

    int64_t batch = inputTy.getDimSize(0);

    int64_t outputChannels = weightTy.getDimSize(0);
    int64_t weightHeight = weightTy.getDimSize(1);
    int64_t weightWidth = weightTy.getDimSize(2);
    int64_t inputChannels = weightTy.getDimSize(3);

    // Pad the weight so that its spatial size is a multiple of the stride.
    llvm::SmallVector<int32_t, 8> weightPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    weightPadding[3] =
        weightHeight % stride[0] ? stride[0] - weightHeight % stride[0] : 0;
    weightPadding[5] =
        weightWidth % stride[1] ? stride[1] - weightWidth % stride[1] : 0;
    DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
    Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

    if (op.quantization_info().hasValue()) {
      auto quantInfo = op.quantization_info().getValue();
      weight = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeight_zp()));

    } else {
      weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                             UnrankedTensorType::get(weightETy),
                                             weight, weightPaddingVal);
    }

    weightTy = weight.getType().cast<ShapedType>();
    weightHeight = weightTy.getDimSize(1);
    weightWidth = weightTy.getDimSize(2);
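
    // Reorganize the padded weights so that each (stride_y, stride_x)
    // sub-pixel offset gets its own group of output channels: a reshape
    // exposes the stride factors, a transpose moves them in front of the
    // output channels, and a second reshape collapses them into the
    // output-channel dimension.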
    // Split out the width / height by the stride dimensions.
    llvm::SmallVector<int64_t, 6> weightReshapeDims0 = {
        outputChannels, weightHeight / stride[0],
        stride[0],      weightWidth / stride[1],
        stride[1],      inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims0));

    // Transpose the factored-out stride to the output channels.
    Value transposeWeightVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

    weight = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        transposeWeightVal);

    // Collapse the strides and output channels into a single dimension.
    llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
        outputChannels * stride[0] * stride[1], weightHeight / stride[0],
        weightWidth / stride[1], inputChannels};
    weight = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64ArrayAttr(weightReshapeDims1));
    ShapedType restridedWeightTy = weight.getType().cast<ShapedType>();

    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(1));
    weight = createOpAndInfer<tosa::ReverseOp>(
        rewriter, loc, UnrankedTensorType::get(weightETy), weight,
        rewriter.getI64IntegerAttr(2));

    // We need to pad the input far enough that we can pull all values.
    llvm::SmallVector<int32_t, 8> inputPadding = {0, 0, 0, 0, 0, 0, 0, 0};
    inputPadding[2] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[3] += restridedWeightTy.getDimSize(1) - 1;
    inputPadding[4] += restridedWeightTy.getDimSize(2) - 1;
    inputPadding[5] += restridedWeightTy.getDimSize(2) - 1;

    DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
        RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

    Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
        rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

    if (op.quantization_info().hasValue()) {
      auto quantInfo = op.quantization_info().getValue();
      input = createOpAndInfer<tosa::PadOp>(
          rewriter, loc, UnrankedTensorType::get(inputETy), input,
          inputPaddingVal, nullptr,
          rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInput_zp()));
    } else {
      input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
                                            UnrankedTensorType::get(inputETy),
                                            input, inputPaddingVal);
    }

    // We use a zero bias as we need to broadcast the bias.
    auto zeroBias = rewriter.create<tosa::ConstOp>(
        loc,
        RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                              biasETy),
        DenseElementsAttr::get(
            RankedTensorType::get({outputChannels * stride[0] * stride[1]},
                                  biasETy),
            rewriter.getZeroAttr(biasETy)));

    // Perform the convolution using the zero bias.
    Value conv2d;
    if (op.quantization_info().hasValue()) {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}),
                   op.quantization_info().getValue())
                   .getResult();
    } else {
      conv2d = createOpAndInfer<tosa::Conv2DOp>(
                   rewriter, loc, UnrankedTensorType::get(resultETy), input,
                   weight, zeroBias,
                   /*pad=*/rewriter.getI64ArrayAttr({0, 0, 0, 0}),
                   /*stride=*/rewriter.getI64ArrayAttr({1, 1}),
                   /*dilation=*/rewriter.getI64ArrayAttr({1, 1}))
                   .getResult();
    }
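
    // The unit-stride convolution produced stride[0] * stride[1] sub-pixel
    // outputs folded into the channel dimension; the reshapes, transpose, and
    // slice below redistribute them into the spatial width / height and crop
    // to the requested output size.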
    // Factor the resulting width / height.
    ShapedType convTy = conv2d.getType().cast<ShapedType>();
    Type convETy = convTy.getElementType();

    int64_t convHeight = convTy.getDimSize(1);
    int64_t convWidth = convTy.getDimSize(2);

    // Factor striding out of the convolution result.
    llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
        batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims0));

    // Interleave the factored-out strides with the spatial dimensions.
    Value transposeConvVal = rewriter.create<tosa::ConstOp>(
        loc, RankedTensorType::get({6}, rewriter.getI32Type()),
        rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

    conv2d = createOpAndInfer<tosa::TransposeOp>(
        rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
        transposeConvVal);

    // Fuse striding behavior back into width / height.
    llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
        batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
    conv2d = createOpAndInfer<tosa::ReshapeOp>(
        rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
        rewriter.getI64ArrayAttr(convReshapeDims1));

    // Slice out the final result.
    llvm::SmallVector<int64_t, 4> sliceBegin = {0, 0, 0, 0};
    llvm::SmallVector<int64_t, 4> sliceSize(resultTy.getShape().begin(),
                                            resultTy.getShape().end());
    sliceBegin[1] = pad[0];
    sliceBegin[2] = pad[1];

    auto slice = createOpAndInfer<tosa::SliceOp>(
                     rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
                     rewriter.getI64ArrayAttr(sliceBegin),
                     rewriter.getI64ArrayAttr(resultTy.getShape()))
                     .getResult();

    auto addBias =
        createOpAndInfer<tosa::AddOp>(rewriter, loc, op.getType(), slice, bias);

    rewriter.replaceOp(op, addBias.getResult());

    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaDecomposeTransposeConv(
    MLIRContext *ctx, RewritePatternSet &patterns) {
  patterns.add<TransposeConvDilatedConverter>(ctx);
  patterns.add<TransposeConvStridedConverter>(ctx);
}