1 //===- BuiltinAttributeInterfaces.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributeInterfaces.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "llvm/ADT/Sequence.h" 12 13 using namespace mlir; 14 using namespace mlir::detail; 15 16 //===----------------------------------------------------------------------===// 17 /// Tablegen Interface Definitions 18 //===----------------------------------------------------------------------===// 19 20 #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc" 21 22 //===----------------------------------------------------------------------===// 23 // ElementsAttr 24 //===----------------------------------------------------------------------===// 25 26 ShapedType ElementsAttr::getType() const { 27 return Attribute::getType().cast<ShapedType>(); 28 } 29 30 Type ElementsAttr::getElementType(Attribute elementsAttr) { 31 return elementsAttr.getType().cast<ShapedType>().getElementType(); 32 } 33 34 int64_t ElementsAttr::getNumElements(Attribute elementsAttr) { 35 return elementsAttr.getType().cast<ShapedType>().getNumElements(); 36 } 37 38 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) { 39 // Verify that the rank of the indices matches the held type. 40 int64_t rank = type.getRank(); 41 if (rank == 0 && index.size() == 1 && index[0] == 0) 42 return true; 43 if (rank != static_cast<int64_t>(index.size())) 44 return false; 45 46 // Verify that all of the indices are within the shape dimensions. 47 ArrayRef<int64_t> shape = type.getShape(); 48 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 49 int64_t dim = static_cast<int64_t>(index[i]); 50 return 0 <= dim && dim < shape[i]; 51 }); 52 } 53 bool ElementsAttr::isValidIndex(Attribute elementsAttr, 54 ArrayRef<uint64_t> index) { 55 return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index); 56 } 57 58 uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr, 59 ArrayRef<uint64_t> index) { 60 ShapedType type = elementsAttr.getType().cast<ShapedType>(); 61 assert(isValidIndex(type, index) && "expected valid multi-dimensional index"); 62 63 // Reduce the provided multidimensional index into a flattended 1D row-major 64 // index. 65 auto rank = type.getRank(); 66 auto shape = type.getShape(); 67 uint64_t valueIndex = 0; 68 uint64_t dimMultiplier = 1; 69 for (int i = rank - 1; i >= 0; --i) { 70 valueIndex += index[i] * dimMultiplier; 71 dimMultiplier *= shape[i]; 72 } 73 return valueIndex; 74 } 75