1 //===- BuiltinAttributeInterfaces.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/BuiltinAttributeInterfaces.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "llvm/ADT/Sequence.h" 13 14 using namespace mlir; 15 using namespace mlir::detail; 16 17 //===----------------------------------------------------------------------===// 18 /// Tablegen Interface Definitions 19 //===----------------------------------------------------------------------===// 20 21 #include "mlir/IR/BuiltinAttributeInterfaces.cpp.inc" 22 23 //===----------------------------------------------------------------------===// 24 // ElementsAttr 25 //===----------------------------------------------------------------------===// 26 27 ShapedType ElementsAttr::getType() const { 28 return Attribute::getType().cast<ShapedType>(); 29 } 30 31 Type ElementsAttr::getElementType(Attribute elementsAttr) { 32 return elementsAttr.getType().cast<ShapedType>().getElementType(); 33 } 34 35 int64_t ElementsAttr::getNumElements(Attribute elementsAttr) { 36 return elementsAttr.getType().cast<ShapedType>().getNumElements(); 37 } 38 39 bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef<uint64_t> index) { 40 // Verify that the rank of the indices matches the held type. 41 int64_t rank = type.getRank(); 42 if (rank == 0 && index.size() == 1 && index[0] == 0) 43 return true; 44 if (rank != static_cast<int64_t>(index.size())) 45 return false; 46 47 // Verify that all of the indices are within the shape dimensions. 48 ArrayRef<int64_t> shape = type.getShape(); 49 return llvm::all_of(llvm::seq<int>(0, rank), [&](int i) { 50 int64_t dim = static_cast<int64_t>(index[i]); 51 return 0 <= dim && dim < shape[i]; 52 }); 53 } 54 bool ElementsAttr::isValidIndex(Attribute elementsAttr, 55 ArrayRef<uint64_t> index) { 56 return isValidIndex(elementsAttr.getType().cast<ShapedType>(), index); 57 } 58 59 uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef<uint64_t> index) { 60 ShapedType shapeType = type.cast<ShapedType>(); 61 assert(isValidIndex(shapeType, index) && 62 "expected valid multi-dimensional index"); 63 64 // Reduce the provided multidimensional index into a flattended 1D row-major 65 // index. 66 auto rank = shapeType.getRank(); 67 ArrayRef<int64_t> shape = shapeType.getShape(); 68 uint64_t valueIndex = 0; 69 uint64_t dimMultiplier = 1; 70 for (int i = rank - 1; i >= 0; --i) { 71 valueIndex += index[i] * dimMultiplier; 72 dimMultiplier *= shape[i]; 73 } 74 return valueIndex; 75 } 76 77 //===----------------------------------------------------------------------===// 78 // MemRefLayoutAttrInterface 79 //===----------------------------------------------------------------------===// 80 81 LogicalResult mlir::detail::verifyAffineMapAsLayout( 82 AffineMap m, ArrayRef<int64_t> shape, 83 function_ref<InFlightDiagnostic()> emitError) { 84 if (m.getNumDims() != shape.size()) 85 return emitError() << "memref layout mismatch between rank and affine map: " 86 << shape.size() << " != " << m.getNumDims(); 87 88 return success(); 89 } 90