1 //===- BuiltinTypeInterfaces.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinTypes.h" 10 #include "mlir/IR/Diagnostics.h" 11 #include "llvm/ADT/Sequence.h" 12 13 using namespace mlir; 14 using namespace mlir::detail; 15 16 //===----------------------------------------------------------------------===// 17 /// Tablegen Interface Definitions 18 //===----------------------------------------------------------------------===// 19 20 #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" 21 22 //===----------------------------------------------------------------------===// 23 // ShapedType 24 //===----------------------------------------------------------------------===// 25 26 constexpr int64_t ShapedType::kDynamicSize; 27 constexpr int64_t ShapedType::kDynamicStrideOrOffset; 28 getNumElements(ArrayRef<int64_t> shape)29int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) { 30 int64_t num = 1; 31 for (int64_t dim : shape) { 32 num *= dim; 33 assert(num >= 0 && "integer overflow in element count computation"); 34 } 35 return num; 36 } 37 getSizeInBits() const38int64_t ShapedType::getSizeInBits() const { 39 assert(hasStaticShape() && 40 "cannot get the bit size of an aggregate with a dynamic shape"); 41 42 auto elementType = getElementType(); 43 if (elementType.isIntOrFloat()) 44 return elementType.getIntOrFloatBitWidth() * getNumElements(); 45 46 if (auto complexType = elementType.dyn_cast<ComplexType>()) { 47 elementType = complexType.getElementType(); 48 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; 49 } 50 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); 51 } 52