1 //===- BuiltinTypeInterfaces.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/Sequence.h"
#include <limits>
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 //===----------------------------------------------------------------------===//
// TableGen Interface Definitions
18 //===----------------------------------------------------------------------===//
19 
20 #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
21 
22 //===----------------------------------------------------------------------===//
23 // ShapedType
24 //===----------------------------------------------------------------------===//
25 
// Out-of-line definitions of ShapedType's static constexpr data members.
// Before C++17, a static constexpr member that is ODR-used (for example, bound
// to a const reference) needs exactly one namespace-scope definition, which
// these lines provide. From C++17 on such members are implicitly inline and
// these definitions are redundant (deprecated) but harmless.
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;
28 
getNumElements(ArrayRef<int64_t> shape)29 int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
30   int64_t num = 1;
31   for (int64_t dim : shape) {
32     num *= dim;
33     assert(num >= 0 && "integer overflow in element count computation");
34   }
35   return num;
36 }
37 
getSizeInBits() const38 int64_t ShapedType::getSizeInBits() const {
39   assert(hasStaticShape() &&
40          "cannot get the bit size of an aggregate with a dynamic shape");
41 
42   auto elementType = getElementType();
43   if (elementType.isIntOrFloat())
44     return elementType.getIntOrFloatBitWidth() * getNumElements();
45 
46   if (auto complexType = elementType.dyn_cast<ComplexType>()) {
47     elementType = complexType.getElementType();
48     return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
49   }
50   return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
51 }
52