//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

15 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
16 #include "mlir/Dialect/Quant/QuantOps.h"
17 #include "mlir/Dialect/Tensor/IR/Tensor.h"
18 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
19 #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/Matchers.h"
23 #include "mlir/IR/PatternMatch.h"
24 #include "mlir/Transforms/InliningUtils.h"
25 #include "llvm/ADT/DenseMap.h"
26 #include "llvm/ADT/TypeSwitch.h"
27
28 using namespace mlir;
29 using namespace mlir::tosa;
30
31 #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
32
33 //===----------------------------------------------------------------------===//
34 // Tosa dialect interface includes.
35 //===----------------------------------------------------------------------===//
36
37 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
38
namespace {
//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
Region &tosa::WhileOp::getLoopBody() { return getBody(); }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaInlinerInterface>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Unlike the standard dialect constant op, which accepts any attribute,
  // Tosa dialect constants only support ElementsAttr.
  if (value.isa<ElementsAttr>())
    return builder.create<tosa::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType =
      op.getInput().getType().template dyn_cast<RankedTensorType>();
  auto weightType =
      op.getWeight().getType().template dyn_cast<RankedTensorType>();

  // Must be ranked tensor types.
  if (!inputType) {
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
    return failure();
  }
  if (!weightType) {
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
    return failure();
  }

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !inputEType.template isa<FloatType>();
  bool weightIsQuant = !weightEType.template isa<FloatType>();

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant) {
    op.emitOpError(
        "expect input and weight to be both float or both quantized, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // A quantized type must carry the quantization_info attribute, and an
  // unquantized (float) type must not.
  if ((inputIsQuant && !op.getQuantizationInfo()) ||
      (!inputIsQuant && op.getQuantizationInfo())) {
    op.emitOpError("quantization_info is required for quantized type, and not "
                   "allowed for float type");
    return failure();
  }

  return success();
}

LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputETy = getInput().getType().cast<ShapedType>().getElementType();
  auto resultETy = getType().cast<ShapedType>().getElementType();

  if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())
    inputETy = quantType.getStorageType();

  if (auto quantType = resultETy.dyn_cast<mlir::quant::UniformQuantizedType>())
    resultETy = quantType.getStorageType();

  if (inputETy.isF32() && resultETy.isF32())
    return success();
  if (inputETy.isInteger(8) && resultETy.isInteger(8))
    return success();
  if (inputETy.isInteger(16) && resultETy.isInteger(16))
    return success();

  return emitOpError("input/output element types are incompatible.");
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bit width of the output given the bit widths of the input and weight
/// content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, ArrayAttr pad,
                                     ArrayAttr stride, ArrayAttr dilation) {

  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void buildTransConvOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          Value weight, Value bias,
                                          ArrayAttr outpad, ArrayAttr stride,
                                          ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("out_shape", outputShape);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {

  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also used where a fully_connected op must be
/// constructed but the weight is not a constant. In that case the
/// fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = a.getType().dyn_cast<ShapedType>();
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = inputType.getElementType()
                          .dyn_cast<mlir::quant::UniformQuantizedType>();
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = outputType.dyn_cast<ShapedType>();
    assert(outputShapedType && "Output must be a shaped type");

    // 16-bit quantized inputs accumulate into a 48-bit accumulator; narrower
    // inputs accumulate into 32 bits.
    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}

/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as
/// it has additional parameters not part of the unary ops.
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          ArrayAttr kernel, ArrayAttr stride,
                                          ArrayAttr pad) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have a
/// scale relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator, which needs to create its
/// own OptionalAttr quantization_attr parameter to scale the padding values
/// correctly. No pad_const is interpreted as zero-padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on the TOSA pad operator when an explicit pad_const
/// value is passed in. It also optionally constructs quantization_attr.
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                 OperationState &result,
                                                 Type outputType, Value input,
                                                 Value paddings,
                                                 Value padConst) {
  result.addOperands({input, paddings, padConst});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//

static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
  }
}

static void getF64Values(ArrayAttr arrayAttr, SmallVector<double> &values) {
  for (auto it : arrayAttr) {
    values.push_back(it.cast<FloatAttr>().getValueAsDouble());
  }
}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  IntegerAttr axis = attributes.get("axis").cast<IntegerAttr>();
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
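
// Worked example (illustrative, not part of the original source): for an
// input of shape [2, 3, 4] and axis = 1, the axis-1 dimension is dropped and
// the inferred output shape is [2, 4].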

LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  int32_t axis =
      attributes.get("axis").cast<IntegerAttr>().getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : operands) {
    ShapeAdaptor operandShape = operands.getShape(operand);
    if (!operandShape.hasRank())
      continue;

    // Copy the operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamicSize);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamicSize)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return failure();
    }

    hasRankedInput = true;
  }

  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int concatDimSize = 0;
  for (auto operand : operands) {
    ShapeAdaptor operandShape = operands.getShape(operand);

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamicSize;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
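
// Worked example (illustrative): concatenating shapes [2, 3] and [4, 3]
// along axis 0 yields [6, 3]. If any input's axis-0 extent is dynamic or
// unranked, the concatenated dimension becomes dynamic as well.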

LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor weightShape = operands.getShape(1);
  ShapeAdaptor biasShape = operands.getShape(2);

  // Initialize all dimensions as dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(2, ShapedType::kDynamicSize);

  if (inputShape.hasRank()) {
    outShape[0] = inputShape.getDimSize(0);
  }

  if (weightShape.hasRank()) {
    outShape[1] = weightShape.getDimSize(0);
  }

  if (biasShape.hasRank()) {
    outShape[1] = outShape[1] == ShapedType::kDynamicSize
                      ? biasShape.getDimSize(0)
                      : outShape[1];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
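
// Worked example (illustrative): input [N, IC] with weight [OC, IC] and bias
// [OC] infers the output shape [N, OC]; the bias shape only fills in OC when
// the weight shape is unranked.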

LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }

LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor lhsShape = operands.getShape(0);
  ShapeAdaptor rhsShape = operands.getShape(1);

  // Initialize all dimensions as dynamic.
  SmallVector<int64_t> outShape;
  outShape.resize(3, ShapedType::kDynamicSize);

  if (lhsShape.hasRank()) {
    outShape[0] = lhsShape.getDimSize(0);
    outShape[1] = lhsShape.getDimSize(1);
  }

  if (rhsShape.hasRank()) {
    outShape[0] = outShape[0] == ShapedType::kDynamicSize
                      ? rhsShape.getDimSize(0)
                      : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
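
// Worked example (illustrative): a batched matmul of [N, H, C] with
// [N, C, W] infers [N, H, W]; the batch dimension N is taken from the lhs
// when it is ranked, otherwise from the rhs.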

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor paddingShape = operands.getShape(1);
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown, we can infer the output rank from the
  // padding shape's first dim.
  if (!inputShape.hasRank()) {
    if (paddingShape.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(operands[1], m_Constant(&paddings))) {
    outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamicSize);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
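
// Worked example (illustrative): an input of shape [2, 3] with constant
// paddings [[1, 1], [2, 2]] infers [2 + 1 + 1, 3 + 2 + 2] = [4, 7].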

LogicalResult tosa::SliceOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ArrayAttr sizes = SliceOpAdaptor(operands, attributes).getSize();
  SmallVector<int64_t> outputShape;
  outputShape.reserve(sizes.size());
  for (auto val : sizes) {
    outputShape.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);
  return success();
}

LogicalResult tosa::TileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TileOpAdaptor adaptor(operands, attributes);
  ArrayAttr multiples = adaptor.getMultiples();
  ShapeAdaptor inputShape = operands.getShape(0);
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(multiples.size(), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // We need the multiple values to determine the output shape.
  SmallVector<int64_t> multipleValues;
  multipleValues.reserve(multiples.size());
  for (auto val : multiples) {
    multipleValues.push_back(val.cast<IntegerAttr>().getValue().getSExtValue());
  }

  // Any non-dynamic dimension can be multiplied to a known size.
  outputShape.reserve(multiples.size());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    int64_t dim = inputShape.getDimSize(i);
    if (dim != ShapedType::kDynamicSize)
      dim *= multipleValues[i];
    outputShape.push_back(dim);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
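
// Worked example (illustrative): tiling an input of shape [2, ?] with
// multiples [3, 2] infers [6, ?]; dynamic dimensions stay dynamic.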

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ReshapeOpAdaptor adaptor(operands, attributes);
  ShapeAdaptor inputShape = operands.getShape(0);

  ArrayAttr newShape = adaptor.getNewShape();
  llvm::SmallVector<int64_t> newShapeValue;
  getI64Values(newShape, newShapeValue);

  // We cannot infer from the total number of elements, so we must take the
  // shape attribute as exact.
  if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
    return success();
  }

  // Determine the number of elements covered by the static dimensions. This
  // allows us to infer the length of the remaining dynamic dimension.
  int64_t numElements = inputShape.getNumElements();
  int64_t staticMul = 1;
  for (auto val : newShapeValue) {
    if (val != ShapedType::kDynamicSize) {
      staticMul *= val;
    }
  }

  // Determine the length of the dynamic dimension.
  for (auto &val : newShapeValue) {
    if (val == ShapedType::kDynamicSize)
      val = numElements / staticMul;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
  return success();
}
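
// Worked example (illustrative): reshaping a static tensor<2x3x4xf32> (24
// elements) with new_shape = [6, -1] infers the dynamic dimension as
// 24 / 6 = 4, so the output shape is [6, 4].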

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  ShapeAdaptor permsShape = operands.getShape(1);

  // If the input rank or the permutation length is unknown, the output rank
  // is unknown.
  if (!inputShape.hasRank() || !permsShape.hasRank() ||
      permsShape.isDynamicDim(0)) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // This would imply the number of permutations does not match the rank of
  // the input, which is illegal.
  if (permsShape.getDimSize(0) != inputShape.getRank()) {
    return failure();
  }

  // Without the input dims we cannot determine the output dim sizes, but we
  // can determine the output rank.
  SmallVector<int64_t> outputShape;
  if (!inputShape.hasRank()) {
    outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamicSize);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Rank-0 means no permutations matter.
  if (inputShape.getRank() == 0) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Check whether the input dimensions are all the same.
  bool allTheSame = true;
  for (int i = 1, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
      allTheSame = false;
      break;
    }
  }

  // If all of the input dimensions are the same we don't care about the
  // permutation.
  if (allTheSame) {
    outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  outputShape.resize(inputShape.getRank(), ShapedType::kDynamicSize);
  // If the permutations are constant we can directly determine the output
  // shape.
  if (ShapeAdaptor permShape = operands.getValueAsShape(1)) {
    outputShape.reserve(inputShape.getRank());
    for (int i = 0, s = inputShape.getRank(); i < s; i++) {
      outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i));
    }
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
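
// Worked example (illustrative): transposing an input of shape [1, 2, 3]
// with constant perms [2, 0, 1] infers [3, 1, 2]; with non-constant perms
// only the rank (and a uniform dim size, if all dims match) is known.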

LogicalResult tosa::GatherOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamicSize);

  ShapeAdaptor valuesShape = operands.getShape(0);
  if (valuesShape.hasRank()) {
    outputShape[0] = valuesShape.getDimSize(0);
    outputShape[2] = valuesShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape = operands.getShape(1);
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = indicesShape.getDimSize(0);
    if (outputShape[1] == ShapedType::kDynamicSize)
      outputShape[1] = indicesShape.getDimSize(1);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
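
// Worked example (illustrative): gathering from values of shape [N, K, C]
// with indices of shape [N, W] infers the output shape [N, W, C].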

LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ResizeOpAdaptor adaptor(operands, attributes);
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamicSize);

  int32_t inHeight = ShapedType::kDynamicSize;
  int32_t inWidth = ShapedType::kDynamicSize;

  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    outputShape[3] = inputShape.getDimSize(3);

    inHeight = inputShape.getDimSize(1);
    inWidth = inputShape.getDimSize(2);
  }

  int32_t shift = adaptor.getShift();
  llvm::SmallVector<int64_t> newShape;
  getI64Values(adaptor.getOutputSize(), newShape);
  outputShape[1] = newShape[0];
  outputShape[2] = newShape[1];

  llvm::SmallVector<int64_t> strideInt;
  llvm::SmallVector<int64_t> offsetInt;
  llvm::SmallVector<double> strideFp;
  llvm::SmallVector<double> offsetFp;
  getI64Values(adaptor.getOffset(), offsetInt);
  getF64Values(adaptor.getOffsetFp(), offsetFp);
  getI64Values(adaptor.getStride(), strideInt);
  getF64Values(adaptor.getStrideFp(), strideFp);

  // A zero integer stride means the resize indexing needs to be performed in
  // floating point. Use the floating-point variant to compute the resize
  // shape.
  bool fpMode = strideInt[0] == 0;

  // If the attribute leaves a dimension unknown, we can compute the output
  // shape from the offset and stride. If we line up perfectly with the last
  // index, we need to round up the size to include it.
  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && fpMode) {
    float sizeFp = (inHeight - offsetFp[0] - 1) / strideFp[0];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[1] = std::ceil(sizeFp) + round;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && fpMode) {
    float sizeFp = (inWidth - offsetFp[1] - 1) / strideFp[1];
    float round = std::floor(sizeFp) == sizeFp ? 1 : 0;
    outputShape[2] = std::ceil(sizeFp) + round;
  }

  if (outputShape[1] == ShapedType::kDynamicSize && inHeight >= 0 && !fpMode) {
    int64_t size = (inHeight - 1);
    size = ((size << shift) - offsetInt[0]) / strideInt[0];
    outputShape[1] = size + 1;
  }

  if (outputShape[2] == ShapedType::kDynamicSize && inWidth >= 0 && !fpMode) {
    int64_t size = (inWidth - 1);
    size = ((size << shift) - offsetInt[1]) / strideInt[1];
    outputShape[2] = size + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
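
// Worked example (illustrative): in integer mode with shift = 0, offset = 0,
// and stride = 1, an input height of 8 with a dynamic output_size entry
// infers an output height of ((8 - 1) << 0 - 0) / 1 + 1 = 8.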

LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamicSize);

  ShapeAdaptor valuesInShape = operands.getShape(0);
  if (valuesInShape.hasRank()) {
    outputShape[0] = valuesInShape.getDimSize(0);
    outputShape[1] = valuesInShape.getDimSize(1);
    outputShape[2] = valuesInShape.getDimSize(2);
  }

  ShapeAdaptor indicesShape = operands.getShape(1);
  if (indicesShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = indicesShape.getDimSize(0);
  }

  ShapeAdaptor inputShape = operands.getShape(2);
  if (inputShape.hasRank()) {
    if (outputShape[0] == ShapedType::kDynamicSize)
      outputShape[0] = inputShape.getDimSize(0);
    if (outputShape[2] == ShapedType::kDynamicSize)
      outputShape[2] = inputShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

static LogicalResult ReduceInferReturnTypes(
    ShapeAdaptor operandShape, IntegerAttr axis,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  if (!operandShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outputShape;
  operandShape.getDims(outputShape);
  int64_t axisVal = axis.getValue().getSExtValue();
  outputShape[axisVal] = 1;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

#define REDUCE_SHAPE_INFER(OP)                                                \
  LogicalResult OP::inferReturnTypeComponents(                                \
      MLIRContext *context, ::llvm::Optional<Location> location,              \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      RegionRange regions,                                                    \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
    return ReduceInferReturnTypes(operands.getShape(0),                       \
                                  attributes.get("axis").cast<IntegerAttr>(), \
                                  inferredReturnShapes);                      \
  }

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProdOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER

static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      return failure();
    }
    outRank = std::max<int64_t>(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}
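
// Worked example (illustrative): broadcasting shapes [1, 3] and [2, 1]
// resolves to [2, 3]; ranks are right-aligned, so [4, 3] and [3] resolve to
// [4, 3], while [2, 3] and [4, 3] fail to broadcast.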

static LogicalResult NAryInferReturnTypes(
    const ValueShapeRange &operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  }
  return success();
}

#define NARY_SHAPE_INFER(OP)                                                  \
  LogicalResult OP::inferReturnTypeComponents(                                \
      MLIRContext *context, ::llvm::Optional<Location> location,              \
      ValueShapeRange operands, DictionaryAttr attributes,                    \
      RegionRange regions,                                                    \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {          \
    return NAryInferReturnTypes(operands, inferredReturnShapes);              \
  }

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::EqualOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReluNOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

static LogicalResult poolingInferReturnTypes(
    const ValueShapeRange &operands, DictionaryAttr attributes,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape = operands.getShape(0);
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, -1);

  // If the input is unranked, we only know the output rank; all dimensions
  // stay dynamic.
  if (!inputShape) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and number of channels are identical for pooling layers.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);

  int32_t height = inputShape.getDimSize(1);
  int32_t width = inputShape.getDimSize(2);

  llvm::SmallVector<int64_t> kernel;
  llvm::SmallVector<int64_t> stride;
  llvm::SmallVector<int64_t> pad;

  getI64Values(attributes.get("kernel").cast<ArrayAttr>(), kernel);
  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
  getI64Values(attributes.get("pad").cast<ArrayAttr>(), pad);

  if (height != -1) {
    int32_t padded = height + pad[0] + pad[1] - kernel[0];
    outputShape[1] = padded / stride[0] + 1;
  }

  if (width != -1) {
    int32_t padded = width + pad[2] + pad[3] - kernel[1];
    outputShape[2] = padded / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
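
// Worked example (illustrative): input height 8, pad [0, 0], kernel 2, and
// stride 2 give (8 + 0 + 0 - 2) / 2 + 1 = 4 output rows.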

LogicalResult Conv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  Conv2DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    outputShape[3] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
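
// Worked example (illustrative): input height 8 with pad [0, 0], a filter of
// height 3, dilation 1, and stride 2 gives filterSize = 3,
// unstridedResult = 8 - 3 + 1 = 6, and output height (6 - 1) / 2 + 1 = 3.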

LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }

LogicalResult Conv3DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
  Conv3DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputDepth = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t weightDepth = ShapedType::kDynamicSize;

  // Input shape describes input width/height/depth and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputDepth = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height/depth and the output
  // channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    outputShape[4] = weightShape.getDimSize(0);
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
    weightDepth = weightShape.getDimSize(3);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[4] = ShapedType::isDynamic(outputShape[4])
                         ? biasShape.getDimSize(0)
                         : outputShape[4];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  if (!ShapedType::isDynamic(inputDepth) &&
      !ShapedType::isDynamic(weightDepth)) {
    int32_t inputSize = inputDepth + padding[4] + padding[5];
    int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }

LogicalResult AvgPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult MaxPool2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  return poolingInferReturnTypes(operands, attributes, inferredReturnShapes);
}

LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
  DepthwiseConv2DOp::Adaptor adaptor(operands.getValues(), attributes);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t inputChannels = ShapedType::kDynamicSize;

  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;
  int32_t depthChannels = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = inputShape.getDimSize(0);
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
    inputChannels = inputShape.getDimSize(3);
  }

  // Weight shape describes the filter width/height and the channel
  // multiplier.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getWeight());
  if (weightShape.hasRank()) {
    weightHeight = weightShape.getDimSize(0);
    weightWidth = weightShape.getDimSize(1);
    inputChannels = ShapedType::isDynamic(inputChannels)
                        ? weightShape.getDimSize(2)
                        : inputChannels;
    depthChannels = weightShape.getDimSize(3);
  }

  // If both inputChannels and depthChannels are available we can determine
  // the output channels.
  if (!ShapedType::isDynamic(inputChannels) &&
      !ShapedType::isDynamic(depthChannels)) {
    outputShape[3] = inputChannels * depthChannels;
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> dilation;
  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getDilation(), dilation);
  getI64Values(adaptor.getPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t inputSize = inputHeight + padding[0] + padding[1];
    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t inputSize = inputWidth + padding[2] + padding[3];
    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
    int32_t unstridedResult = inputSize - filterSize + 1;
    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
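
// Worked example (illustrative): a depthwise conv whose input has 8 channels
// and whose weight has a channel multiplier of 2 infers 8 * 2 = 16 output
// channels.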

LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }

LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  TransposeConv2DOp::Adaptor adaptor(operands.getValues(), attributes);
  llvm::SmallVector<int64_t> outputShape;
  getI64Values(adaptor.getOutShape(), outputShape);

  int32_t inputWidth = ShapedType::kDynamicSize;
  int32_t inputHeight = ShapedType::kDynamicSize;
  int32_t weightWidth = ShapedType::kDynamicSize;
  int32_t weightHeight = ShapedType::kDynamicSize;

  // Input shape describes input width/height and batch.
  ShapeAdaptor inputShape = operands.getShape(adaptor.getInput());
  if (inputShape.hasRank()) {
    outputShape[0] = ShapedType::isDynamic(outputShape[0])
                         ? inputShape.getDimSize(0)
                         : outputShape[0];
    inputHeight = inputShape.getDimSize(1);
    inputWidth = inputShape.getDimSize(2);
  }

  // Weight shape describes the filter width/height and the output channels.
  ShapeAdaptor weightShape = operands.getShape(adaptor.getFilter());
  if (weightShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? weightShape.getDimSize(0)
                         : outputShape[3];
    weightHeight = weightShape.getDimSize(1);
    weightWidth = weightShape.getDimSize(2);
  }

  // Bias shape can describe the output channels.
  ShapeAdaptor biasShape = operands.getShape(adaptor.getBias());
  if (biasShape.hasRank()) {
    outputShape[3] = ShapedType::isDynamic(outputShape[3])
                         ? biasShape.getDimSize(0)
                         : outputShape[3];
  }

  llvm::SmallVector<int64_t> padding;
  llvm::SmallVector<int64_t> stride;

  getI64Values(adaptor.getOutPad(), padding);
  getI64Values(adaptor.getStride(), stride);

  if (!ShapedType::isDynamic(inputHeight) &&
      !ShapedType::isDynamic(weightHeight)) {
    int32_t calculateSize =
        (inputHeight - 1) * stride[0] - padding[0] - padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (!ShapedType::isDynamic(inputWidth) &&
      !ShapedType::isDynamic(weightWidth)) {
    int32_t calculateSize =
        (inputWidth - 1) * stride[1] - padding[2] - padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
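
// Worked example (illustrative): input height 4, stride 2, out_pad [0, 0],
// and filter height 3 give (4 - 1) * 2 - 0 - 0 + 3 = 9 for the output height
// when out_shape leaves that dimension dynamic.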

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : regions) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::llvm::Optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : *regions[1])
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // TOSA's while must have a tosa.yield as its terminator. If not found this
  // tosa.while is invalid.
  if (yieldOps.empty())
    return failure();

  // Get the initial type information from the operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"