//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

namespace {
//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
Region &tosa::WhileOp::getLoopBody() { return getBody(); }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // TOSA dialect constants only support ElementsAttr, unlike the standard
  // dialect constant, which supports all attributes.
  if (value.isa<ElementsAttr>())
    return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType =
      op.getInput().getType().template dyn_cast<RankedTensorType>();
  auto weightType =
      op.getWeight().getType().template dyn_cast<RankedTensorType>();

  // Must be ranked tensor types.
  if (!inputType) {
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
    return failure();
  }
  if (!weightType) {
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
    return failure();
  }

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !inputEType.template isa<FloatType>();
  bool weightIsQuant = !weightEType.template isa<FloatType>();

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // A quantized type must carry a quantization attr, and an unquantized
  // (float) type must not have one.
  if ((inputIsQuant && !op.getQuantizationInfo()) ||
      (!inputIsQuant && op.getQuantizationInfo())) {
    op.emitOpError("quantizationattr is required for quantized type, and not "
                   "allowed for float type");
    return failure();
  }

  return success();
}

LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputETy = getInput().getType().cast<ShapedType>().getElementType();
  auto resultETy = getType().cast<ShapedType>().getElementType();

  if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())
    inputETy = quantType.getStorageType();

  if (auto quantType = resultETy.dyn_cast<mlir::quant::UniformQuantizedType>())
    resultETy = quantType.getStorageType();

  if (inputETy.isF32() && resultETy.isF32())
    return success();
  if (inputETy.isInteger(8) && resultETy.isInteger(8))
    return success();
  if (inputETy.isInteger(16) && resultETy.isInteger(16))
    return success();

  return emitOpError("input/output element types are incompatible.");
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bit width of the output given the bit widths of the input and weight.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, ArrayAttr pad,
                                     ArrayAttr stride, ArrayAttr dilation) {

  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void buildTransConvOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          Value weight, Value bias,
                                          ArrayAttr outpad, ArrayAttr stride,
                                          ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("out_shape", outputShape);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {

  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also used where a fully_connected op would be
/// constructed but the weight is not a constant; in that case the
/// fully_connected op must be expressed as a matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = a.getType().dyn_cast<ShapedType>();
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = inputType.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = outputType.dyn_cast<ShapedType>();
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
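    // Per the TOSA spec, 16-bit quantized matmul inputs accumulate into a
    // 48-bit integer; narrower integer inputs accumulate into 32 bits.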
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}

/// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr,
/// but the avg_pool2d operator has its own builder as it has additional
/// parameters not part of the unary ops.
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          ArrayAttr kernel, ArrayAttr stride,
                                          ArrayAttr pad) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator, which needs to create its
/// own OptionalAttr quantization_info parameter to scale the padding values
/// correctly. An absent pad_const is interpreted as zero padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator when an explicit pad_const
/// value is passed in. It also optionally constructs quantization_info.
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                 OperationState &result,
                                                 Type outputType, Value input,
                                                 Value paddings,
                                                 Value padConst) {
  result.addOperands({input, paddings, padConst});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}
//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<FloatAttr>().getValueAsDouble());
  }
}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  int32_t axis =
      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : operands) {
    ShapeAdaptor operandShape = operands.getShape(operand);
    if (!operandShape.hasRank())
      continue;

    // Copy the operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamicSize);

    // Copy static dimension sizes, checking that ranked operands agree on
    // every dimension other than the concatenation axis.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamicSize)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return failure();
    }

    hasRankedInput = true;
  }

  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // Determine the dimension size along the concatenation axis.
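  // For example, concatenating [2, 3] and [2, 5] along axis 1 yields [2, 8];
  // if any operand is dynamic along the axis, the result is dynamic.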
  int concatDimSize = 0;
  for (auto operand : operands) {
    ShapeAdaptor operandShape = operands.getShape(operand);

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamicSize;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor weightShape = operands.getShape(1);
  ShapeAdaptor biasShape = operands.getShape(2);

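  // The result is [N, OC]: N from the input's leading dim, OC from the
  // weight's leading dim (or from the bias if the weight is unranked).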
  // All dimensions are initially dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, ShapedType::kDynamicSize);

  if (inputShape.hasRank()) {
    outShape[0] = inputShape.getDimSize(0);
  }

  if (weightShape.hasRank()) {
    outShape[1] = weightShape.getDimSize(0);
  }

  if (biasShape.hasRank()) {
    outShape[1] = outShape[1] == ShapedType::kDynamicSize
                      ? biasShape.getDimSize(0)
                      : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor lhsShape = operands.getShape(0);
  ShapeAdaptor rhsShape = operands.getShape(1);

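  // tosa.matmul is batched: [N, H, C] x [N, C, W] -> [N, H, W].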
  // All dimensions are initially dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, ShapedType::kDynamicSize);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamicSize
                      ? rhsShape.getDimSize(0)
                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor paddingShape = operands.getShape(1);
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown, we can infer the output rank from the
  // padding shape's first dim.
  if (!inputShape.hasRank()) {
    if (paddingShape.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(operands[1], m_Constant(&paddings))) {
    outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

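  // Each output dimension is the input dimension plus that axis' low and high
  // padding, e.g. a [4, 5] input with paddings [[1, 1], [2, 0]] yields [6, 7].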
  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamicSize);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ArrayAttr sizes = SliceOpAdaptor(operands, attributes).getSize();
  SmallVector<int64_t> outputShape;
  outputShape.reserve(sizes.size());
  for (auto val : sizes) {
    outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);
  return success();
}

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TileOpAdaptor adaptor(operands, attributes);
  ArrayAttr multiples = adaptor.getMultiples();
  ShapeAdaptor inputShape = operands.getShape(0);
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // We need the multiple values to determine the output shape.
  SmallVector<int64_t> multipleValues;
  multipleValues.reserve(multiples.size());
  for (auto val : multiples) {
    multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  // Any non-dynamic dimension can be multiplied to a known size.
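  // For example, tiling a [2, ?, 3] input by multiples [3, 2, 1] yields
  // [6, ?, 3].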
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamicSize)
      dim *= multipleValues[i];
    outputShape.push_back(dim);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ReshapeOpAdaptor adaptor(operands, attributes);
  ShapeAdaptor inputShape = operands.getShape(0);

  ArrayAttr newShape = adaptor.getNewShape();
  llvm::SmallVector<int64_t> newShapeValue;
  getI64Values(newShape, newShapeValue);

  // We cannot infer the dynamic dimension from the total number of elements,
  // so we must take the shape attribute as exact.
  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
    return success();
  }

  // Determine the number of elements covered by the static dimensions. This
  // allows us to infer the length of the remaining dynamic dimension.
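  // For example, a static 2x3x4 input (24 elements) reshaped with
  // new_shape = [-1, 6] resolves the dynamic dimension to 24 / 6 = 4.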
  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (val != ShapedType::kDynamicSize) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (val == ShapedType::kDynamicSize)
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
  return success();
}

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor permsShape = operands.getShape(1);

  // If the input rank or the permutation length is unknown, the output rank
  // is unknown.
  if (!inputShape.hasRank() || !permsShape.hasRank() ||
      permsShape.isDynamicDim(0)) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // This would imply the number of permutations does not match the rank of
  // the input, which is illegal.
  if (permsShape.getDimSize(0) != inputShape.getRank()) {
    return failure();
  }

  // Without the input dims we cannot determine the output dim sizes but we
  // can determine the output rank.
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Rank-0 means no permutations matter.
  if (inputShape.getRank() == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
  // If the permutations are a constant we can directly determine the output
  // shape.
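  // For example, an input of shape [1, 2, 3] with perms = [2, 0, 1] produces
  // an output of shape [3, 1, 2].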
  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
    outputShape.reserve(inputShape.getRank());
    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
    }
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamicSize);

  ShapeAdaptor valuesShape = operands.getShape(0);
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape = operands.getShape(1);
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamicSize)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ResizeOpAdaptor adaptor(operands, attributes);
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamicSize);

  int32_t inHeight = ShapedType::kDynamicSize;
  int32_t inWidth = ShapedType::kDynamicSize;

  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    outputShape[3] = inputShape.getDimSize(3);

    inHeight = inputShape.getDimSize(1);
    inWidth = inputShape.getDimSize(2);
  }

  int32_t shift = adaptor.getShift();
  llvm::SmallVector<int64_t> newShape;
  getI64Values(adaptor.getOutputSize(), newShape);
  outputShape[1] = newShape[0];
  outputShape[2] = newShape[1];

  llvm::SmallVector<int64_t> strideInt;
  llvm::SmallVector<int64_t> offsetInt;
  llvm::SmallVector<double> strideFp;
  llvm::SmallVector<double> offsetFp;
  getI64Values(adaptor.getOffset(), offsetInt);
  getF64Values(adaptor.getOffsetFp(), offsetFp);
  getI64Values(adaptor.getStride(), strideInt);
  getF64Values(adaptor.getStrideFp(), strideFp);

  // If the integer stride is zero, the resize indexing must be performed in
  // floating point. Use the floating-point variants to compute the resize
  // shape.
  bool fpMode = strideInt[0] == 0;

  // We can compute unknown output dimensions from the offset and stride
  // attributes. If we line up perfectly with the last index, we round up the
  // size to include it.
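  // For example, with inHeight = 10, offset_fp = 0.0, and stride_fp = 0.5:
  // sizeFp = (10 - 0 - 1) / 0.5 = 18.0, which is exact, so the height becomes
  // ceil(18.0) + 1 = 19.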
  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
    float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[1] = std::ceil(sizeFp) + round;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
    float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[2] = std::ceil(sizeFp) + round;
  }

  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
    int64_t size = (inHeight - 1);
    size = ((size << shift) - offsetInt[0]) / strideInt[0];
    outputShape[1] = size + 1;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
    int64_t size = (inWidth - 1);
    size = ((size << shift) - offsetInt[1]) / strideInt[1];
    outputShape[2] = size + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamicSize);

  ShapeAdaptor valuesInShape = operands.getShape(0);
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape = operands.getShape(1);
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape = operands.getShape(2);
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamicSize)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  if (!operandShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  int64_t axisVal = axis.getValue().getSExtValue();
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return ReduceInferReturnTypes(operands.getShape(0),                        \
                                  attributes.get("axis").cast<IntegerAttr>(),  \
                                  inferredReturnShapes);                       \
  }

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER

static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

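  // Standard broadcasting: e.g. shapes [1, 3] and [2, 1] resolve to [2, 3],
  // while mismatched non-unit dims such as [2, 3] vs. [4, 3] fail.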
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}

static LogicalResult NAryInferReturnTypes(
    const ValueShapeRange &operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  }
  return success();
}

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::llvm::Optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      RegionRange regions,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::EqualOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReluNOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

static LogicalResult poolingInferReturnTypes(
    const ValueShapeRange &operands, DictionaryAttr attributes,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, -1);

  // If the input is unranked, we only know that the output is rank 4 (NHWC)
  // with all dimensions dynamic.
  if (!inputShape) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and channel dimensions pass through a pooling layer unchanged.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);

  int32_t height = inputShape.getDimSize(1);
  int32_t width = inputShape.getDimSize(2);

  llvm::SmallVector<int64_t> kernel;
  llvm::SmallVector<int64_t> stride;
  llvm::SmallVector<int64_t> pad;

  getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);

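  // Pooling output size per spatial dimension:
  //   out = (in + pad_before + pad_after - kernel) / stride + 1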
  if (height != -1) {
    int32_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (width != -1) {
    int32_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // The weight shape describes the filter width/height and the output
  // channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // The bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

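  // Convolution output size per spatial dimension:
  //   out = (in + pads - dilation * (kernel - 1) - 1) / stride + 1
  // e.g. in = 8, pads = 0, kernel = 3, dilation = 1, stride = 2 gives out = 3.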
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
  Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputDepth = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t weightDepth = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputDepth = inputShape.getDimSize(3);
  }

  // The weight shape describes the filter width/height and the output
  // channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
    weightDepth = weightShape.getDimSize(3);
  }

  // The bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[4] = ShapedType::isDynamic(outputShape[4])
                         ? biasShape.getDimSize(0)
                         : outputShape[4];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + padding[4] + padding[5];
    int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputChannels = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t depthChannels = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // The weight shape describes the filter width/height, the input channels,
  // and the depth multiplier.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the output channels.
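  // For example, 8 input channels with a depth multiplier of 2 yield 16
  // output channels.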
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

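  // The spatial dimensions follow the usual dilated-convolution size formula:
  //   out = (in + padBefore + padAfter - ((kernel - 1) * dilation + 1))
  //         / stride + 1
  // For example (illustrative values only): in = 14, pad = 1 + 1, kernel = 3,
  // dilation = 1, stride = 2 gives (16 - 3) / 2 + 1 = 7.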
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }

LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
  llvm::SmallVector<int64_t> outputShape;
  getI64Values(adaptor.getOutShape(), outputShape);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // The weight shape describes the output channels and the filter
  // height/width.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getOutPad(), padding);
  getI64Values(adaptor.getStride(), stride);

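  // Transposed convolution grows the spatial dimensions:
  //   out = (in - 1) * stride - padBefore - padAfter + kernel
  // For example (illustrative values only): in = 7, stride = 2, pad = 1 + 1,
  // kernel = 3 gives (7 - 1) * 2 - 1 - 1 + 3 = 13.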
  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t calculatedSize =
        (inputHeight - 1) * stride[0] - padding[0] - padding[1] + weightHeight;
    outputShape[1] = ShapedType::isDynamic(outputShape[1]) ? calculatedSize
                                                           : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t calculatedSize =
        (inputWidth - 1) * stride[1] - padding[2] - padding[3] + weightWidth;
    outputShape[2] = ShapedType::isDynamic(outputShape[2]) ? calculatedSize
                                                           : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : regions) {
    for (auto &block : *region)
      if (auto yieldOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);
  }

  if (yieldOps.empty())
    return failure();

  // Seed the result knowledge with the operand types of the first yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

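  // Refine the seeded knowledge against every yield. ValueKnowledge::meet
  // combines two pieces of type information when they are compatible; e.g.
  // (illustrative) meeting tensor<1x?xf32> with tensor<1x4xf32> can refine
  // to tensor<1x4xf32>. When the meet fails, the existing knowledge is kept.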
  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
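  // Only the body region (regions[1]) determines the loop-carried result
  // types; the cond region yields just the continuation predicate.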
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : *regions[1])
    if (auto yieldOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(yieldOp);

  // TOSA's while must have a tosa.yield as its terminator. If none is found,
  // this tosa.while is invalid.
  if (yieldOps.empty())
    return failure();

  // Seed the result knowledge with the operand types of the first yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"