//===- TosaMakeBroadcastable.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Insert reshape to binary op's input if needed to match rank
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::tosa;

/// There are two potential ways of implementing broadcast:
/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
/// This pass implements b (numpy style) for now.

/// In this pass, we insert RESHAPE operators to increase the rank of the
/// lower-rank operand as a first step in the broadcasting process. The TOSA
/// operators that support broadcast require that the ranks of the operands
/// are equal.

// Examples:
// If lower=[c], higher=[a, b, c], [c] reshaped into [1, 1, c].
// If lower=[b, c], higher=[a, b, c], [b, c] reshaped into [1, b, c].
// If lower=[a], higher=[a, a], [a] reshaped into [1, a].
// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
41 42 static LogicalResult 43 computeReshapeOutput(ArrayRef<int64_t> higherRankShape, 44 ArrayRef<int64_t> lowerRankShape, 45 SmallVectorImpl<int64_t> &reshapeOutputShape) { 46 // Initialize new shapes with [1] * higherRank. 47 int64_t higherRank = higherRankShape.size(); 48 int64_t lowerRank = lowerRankShape.size(); 49 50 reshapeOutputShape.assign(higherRank, 1); 51 52 int64_t higherRankDim; 53 int64_t lowerRankDim; 54 55 for (int64_t i = higherRank - 1, j = lowerRank - 1; i >= 0 && j >= 0; 56 i--, j--) { 57 higherRankDim = higherRankShape[i]; 58 lowerRankDim = lowerRankShape[j]; 59 60 if (lowerRankDim == 1 && higherRankDim > 1) 61 reshapeOutputShape[i] = 1; 62 else if ((lowerRankDim > 1 && higherRankDim == 1) || 63 (lowerRankDim == higherRankDim)) 64 reshapeOutputShape[i] = lowerRankDim; 65 else if (higherRankDim != lowerRankDim) 66 return failure(); 67 } 68 return success(); 69 } 70 71 /// Common code to create the reshape op where necessary to make the rank of the 72 /// operations equal. Returns the updated input1 and input2 for the original 73 /// input. The caller is expected to use these to rewrite the original operator 74 /// with the RESHAPE now in the graph. 75 static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, 76 Location loc, 77 RankedTensorType outputType, 78 Value input1, Value input2, 79 Value &outInput1, Value &outInput2) { 80 auto input1Ty = input1.getType().dyn_cast<RankedTensorType>(); 81 auto input2Ty = input2.getType().dyn_cast<RankedTensorType>(); 82 83 if (!input1Ty || !input2Ty) 84 return failure(); 85 86 int64_t input1Rank = input1Ty.getRank(); 87 int64_t input2Rank = input2Ty.getRank(); 88 89 Value higherTensorValue, lowerTensorValue; 90 // Cannot rewrite as its already correct. 
91 if (input1Rank == input2Rank) 92 return failure(); 93 94 if (input1Rank > input2Rank) { 95 higherTensorValue = input1; 96 lowerTensorValue = input2; 97 } else { 98 higherTensorValue = input2; 99 lowerTensorValue = input1; 100 } 101 102 ArrayRef<int64_t> higherRankShape = 103 higherTensorValue.getType().cast<RankedTensorType>().getShape(); 104 (void)higherRankShape; 105 ArrayRef<int64_t> lowerRankShape = 106 lowerTensorValue.getType().cast<RankedTensorType>().getShape(); 107 108 SmallVector<int64_t, 4> reshapeOutputShape; 109 110 if (computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape) 111 .failed()) 112 return failure(); 113 114 auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>(); 115 auto reshapeOutputType = RankedTensorType::get( 116 ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType()); 117 118 // Verify the rank agrees with the output type if the output type is ranked. 119 if (outputType) { 120 if (outputType.getShape().size() != reshapeOutputShape.size() || 121 outputType.getShape().size() != higherRankShape.size()) 122 return failure(); 123 } 124 125 auto reshapeLower = rewriter.create<tosa::ReshapeOp>( 126 loc, reshapeOutputType, lowerTensorValue, 127 rewriter.getI64ArrayAttr(reshapeOutputShape)); 128 129 if (input1Rank > input2Rank) { 130 outInput1 = higherTensorValue; 131 outInput2 = reshapeLower.getResult(); 132 } else { 133 outInput1 = reshapeLower.getResult(); 134 outInput2 = higherTensorValue; 135 } 136 137 return success(); 138 } 139 140 namespace { 141 template <typename OpTy> 142 struct ConvertTosaOp : public OpRewritePattern<OpTy> { 143 using OpRewritePattern<OpTy>::OpRewritePattern; 144 145 LogicalResult matchAndRewrite(OpTy tosaBinaryOp, 146 PatternRewriter &rewriter) const override { 147 148 Value input1 = tosaBinaryOp.input1(); 149 Value input2 = tosaBinaryOp.input2(); 150 Value output = tosaBinaryOp.getResult(); 151 152 auto outputType = 
output.getType().dyn_cast<RankedTensorType>(); 153 if (!outputType) 154 return failure(); 155 156 Value outInput1, outInput2; 157 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 158 input1, input2, outInput1, outInput2) 159 .failed()) 160 return failure(); 161 162 rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1, 163 outInput2); 164 165 return success(); 166 } 167 }; 168 169 // The MulOp has an extra parameter 'shift' not present in other elementwise 170 // binary ops, that necessitates special handling of its builder. 171 template <> 172 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> { 173 using OpRewritePattern<tosa::MulOp>::OpRewritePattern; 174 175 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp, 176 PatternRewriter &rewriter) const override { 177 178 Value input1 = tosaBinaryOp.input1(); 179 Value input2 = tosaBinaryOp.input2(); 180 int32_t shift = tosaBinaryOp.shift(); 181 Value output = tosaBinaryOp.getResult(); 182 auto outputType = output.getType().dyn_cast<RankedTensorType>(); 183 if (!outputType) 184 return failure(); 185 186 Value outInput1, outInput2; 187 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 188 input1, input2, outInput1, outInput2) 189 .failed()) 190 return failure(); 191 192 rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType, 193 outInput1, outInput2, shift); 194 195 return success(); 196 } 197 }; 198 199 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in 200 // other elementwise binary ops, that necessitates special handling of its 201 // builder. 
202 template <> 203 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp> 204 : public OpRewritePattern<tosa::ArithmeticRightShiftOp> { 205 using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern; 206 207 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp, 208 PatternRewriter &rewriter) const override { 209 210 Value input1 = tosaBinaryOp.input1(); 211 Value input2 = tosaBinaryOp.input2(); 212 int32_t round = tosaBinaryOp.round(); 213 Value output = tosaBinaryOp.getResult(); 214 auto outputType = output.getType().dyn_cast<RankedTensorType>(); 215 if (!outputType) 216 return failure(); 217 218 Value outInput1, outInput2; 219 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, 220 input1, input2, outInput1, outInput2) 221 .failed()) 222 return failure(); 223 224 rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>( 225 tosaBinaryOp, outputType, outInput1, outInput2, round); 226 227 return success(); 228 } 229 }; 230 } // namespace 231 232 namespace { 233 /// Pass that enables broadcast by making all input arrays have the same 234 /// number of dimensions. Insert RESHAPE operations to lower rank operand 235 struct TosaMakeBroadcastable 236 : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> { 237 public: 238 void runOnFunction() override { 239 auto func = getFunction(); 240 RewritePatternSet patterns(func.getContext()); 241 MLIRContext *ctx = func.getContext(); 242 // Add the generated patterns to the list. 
243 patterns.add<ConvertTosaOp<tosa::BitwiseAndOp>>(ctx); 244 patterns.add<ConvertTosaOp<tosa::BitwiseOrOp>>(ctx); 245 patterns.add<ConvertTosaOp<tosa::BitwiseXorOp>>(ctx); 246 patterns.add<ConvertTosaOp<tosa::AddOp>>(ctx); 247 patterns.add<ConvertTosaOp<tosa::SubOp>>(ctx); 248 patterns.add<ConvertTosaOp<tosa::MulOp>>(ctx); 249 patterns.add<ConvertTosaOp<tosa::DivOp>>(ctx); 250 patterns.add<ConvertTosaOp<tosa::MaximumOp>>(ctx); 251 patterns.add<ConvertTosaOp<tosa::MinimumOp>>(ctx); 252 patterns.add<ConvertTosaOp<tosa::EqualOp>>(ctx); 253 patterns.add<ConvertTosaOp<tosa::GreaterOp>>(ctx); 254 patterns.add<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx); 255 patterns.add<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx); 256 patterns.add<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx); 257 patterns.add<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx); 258 patterns.add<ConvertTosaOp<tosa::LogicalAndOp>>(ctx); 259 patterns.add<ConvertTosaOp<tosa::LogicalOrOp>>(ctx); 260 patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx); 261 patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx); 262 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 263 } 264 }; 265 } // namespace 266 267 std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() { 268 return std::make_unique<TosaMakeBroadcastable>(); 269 } 270