1 //===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributeInterfaces.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "llvm/ADT/Sequence.h"
12 
13 using namespace mlir;
14 using namespace mlir::detail;
15 
16 //===----------------------------------------------------------------------===//
// Tablegen Interface Definitions
18 //===----------------------------------------------------------------------===//
19 
20 #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
21 
22 //===----------------------------------------------------------------------===//
23 // ElementsAttr
24 //===----------------------------------------------------------------------===//
25 
26 ShapedType ElementsAttr::getType() const {
27   return Attribute::getType().cast<ShapedType>();
28 }
29 
30 Type ElementsAttr::getElementType(Attribute elementsAttr) {
31   return elementsAttr.getType().cast<ShapedType>().getElementType();
32 }
33 
34 int64_t ElementsAttr::getNumElements(Attribute elementsAttr) {
35   return elementsAttr.getType().cast<ShapedType>().getNumElements();
36 }
37 
38 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
39   // Verify that the rank of the indices matches the held type.
40   int64_t rank = type.getRank();
41   if (rank == 0 && index.size() == 1 && index[0] == 0)
42     return true;
43   if (rank != static_cast<int64_t>(index.size()))
44     return false;
45 
46   // Verify that all of the indices are within the shape dimensions.
47   ArrayRef<int64_t> shape = type.getShape();
48   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
49     int64_t dim = static_cast<int64_t>(index[i]);
50     return 0 <= dim && dim < shape[i];
51   });
52 }
53 bool ElementsAttr::isValidIndex(Attribute elementsAttr,
54                                 ArrayRef<uint64_t> index) {
55   return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index);
56 }
57 
58 uint64_t ElementsAttr::getFlattenedIndex(Attribute elementsAttr,
59                                          ArrayRef<uint64_t> index) {
60   ShapedType type = elementsAttr.getType().cast<ShapedType>();
61   assert(isValidIndex(type, index) && "expected valid multi-dimensional index");
62 
63   // Reduce the provided multidimensional index into a flattended 1D row-major
64   // index.
65   auto rank = type.getRank();
66   auto shape = type.getShape();
67   uint64_t valueIndex = 0;
68   uint64_t dimMultiplier = 1;
69   for (int i = rank - 1; i >= 0; --i) {
70     valueIndex += index[i] * dimMultiplier;
71     dimMultiplier *= shape[i];
72   }
73   return valueIndex;
74 }
75