1 //===- BuiltinAttributeInterfaces.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/BuiltinAttributeInterfaces.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "llvm/ADT/Sequence.h"
13 
14 using namespace mlir;
15 using namespace mlir::detail;
16 
17 //===----------------------------------------------------------------------===//
18 /// Tablegen Interface Definitions
19 //===----------------------------------------------------------------------===//
20 
21 #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc"
22 
23 //===----------------------------------------------------------------------===//
24 // ElementsAttr
25 //===----------------------------------------------------------------------===//
26 
getType() const27 ShapedType ElementsAttr::getType() const {
28   return Attribute::getType().cast<ShapedType>();
29 }
30 
getElementType(Attribute elementsAttr)31 Type ElementsAttr::getElementType(Attribute elementsAttr) {
32   return elementsAttr.getType().cast<ShapedType>().getElementType();
33 }
34 
getNumElements(Attribute elementsAttr)35 int64_t ElementsAttr::getNumElements(Attribute elementsAttr) {
36   return elementsAttr.getType().cast<ShapedType>().getNumElements();
37 }
38 
isValidIndex(ShapedType type,ArrayRef<uint64_t> index)39 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) {
40   // Verify that the rank of the indices matches the held type.
41   int64_t rank = type.getRank();
42   if (rank == 0 && index.size() == 1 && index[0] == 0)
43     return true;
44   if (rank != static_cast<int64_t>(index.size()))
45     return false;
46 
47   // Verify that all of the indices are within the shape dimensions.
48   ArrayRef<int64_t> shape = type.getShape();
49   return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) {
50     int64_t dim = static_cast<int64_t>(index[i]);
51     return 0 <= dim && dim < shape[i];
52   });
53 }
isValidIndex(Attribute elementsAttr,ArrayRef<uint64_t> index)54 bool ElementsAttr::isValidIndex(Attribute elementsAttr,
55                                 ArrayRef<uint64_t> index) {
56   return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index);
57 }
58 
getFlattenedIndex(Type type,ArrayRef<uint64_t> index)59 uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) {
60   ShapedType shapeType = type.cast<ShapedType>();
61   assert(isValidIndex(shapeType, index) &&
62          "expected valid multi-dimensional index");
63 
64   // Reduce the provided multidimensional index into a flattended 1D row-major
65   // index.
66   auto rank = shapeType.getRank();
67   ArrayRef<int64_t> shape = shapeType.getShape();
68   uint64_t valueIndex = 0;
69   uint64_t dimMultiplier = 1;
70   for (int i = rank - 1; i >= 0; --i) {
71     valueIndex += index[i] * dimMultiplier;
72     dimMultiplier *= shape[i];
73   }
74   return valueIndex;
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // MemRefLayoutAttrInterface
79 //===----------------------------------------------------------------------===//
80 
verifyAffineMapAsLayout(AffineMap m,ArrayRef<int64_t> shape,function_ref<InFlightDiagnostic ()> emitError)81 LogicalResult mlir::detail::verifyAffineMapAsLayout(
82     AffineMap m, ArrayRef<int64_t> shape,
83     function_ref<InFlightDiagnostic()> emitError) {
84   if (m.getNumDims() != shape.size())
85     return emitError() << "memref layout mismatch between rank and affine map: "
86                        << shape.size() << " != " << m.getNumDims();
87 
88   return success();
89 }
90