1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Tensor/IR/Tensor.h" 10 #include "mlir/IR/DialectImplementation.h" 11 #include "mlir/Transforms/InliningUtils.h" 12 #include "llvm/ADT/TypeSwitch.h" 13 14 using namespace mlir; 15 using namespace mlir::tensor; 16 17 //===----------------------------------------------------------------------===// 18 // TableGen'd Attributes Methods 19 //===----------------------------------------------------------------------===// 20 21 #define GET_ATTRDEF_CLASSES 22 #include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc" 23 24 // Dictionary keys. 25 static constexpr StringRef getSparseDimLevelTypeAttrName() { 26 return "sparseDimLevelType"; 27 } 28 static constexpr StringRef getSparseDimOrderingAttrName() { 29 return "sparseDimOrdering"; 30 } 31 static constexpr StringRef getSparsePointerBitWidthAttrName() { 32 return "sparsePointerBitWidth"; 33 } 34 static constexpr StringRef getSparseIndexBitWidthAttrName() { 35 return "sparseIndexBitWidth"; 36 } 37 38 // Dictionary values. 39 static constexpr StringRef getDenseDimLevelTypeVal() { return "dense"; } 40 static constexpr StringRef getCompressedDimLevelTypeVal() { 41 return "compressed"; 42 } 43 static constexpr StringRef getSingletonDimLevelTypeVal() { return "singleton"; } 44 45 Attribute SparseTensorEncodingAttr::parse(MLIRContext *context, 46 DialectAsmParser &parser, Type type) { 47 if (failed(parser.parseLess())) 48 return {}; 49 DictionaryAttr dict; 50 if (failed(parser.parseAttribute(dict))) 51 return {}; 52 if (failed(parser.parseGreater())) 53 return {}; 54 return SparseTensorEncodingAttr::get(context, dict); 55 } 56 57 void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const { 58 printer << "sparse<" << getDict() << ">"; 59 } 60 61 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 62 llvm::ArrayRef<int64_t> shape, Type elementType, 63 llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const { 64 unsigned size = shape.size(); 65 for (const NamedAttribute &attr : getDict()) { 66 if (attr.first == getSparseDimLevelTypeAttrName()) { 67 // Dimension level type verification. 68 auto arrayAttr = attr.second.dyn_cast<ArrayAttr>(); 69 if (!arrayAttr || size != static_cast<int64_t>(arrayAttr.size())) 70 return emitError() << "expected an array of size " << size 71 << " for dimension level types"; 72 for (unsigned i = 0; i < size; i++) { 73 auto strAttr = arrayAttr[i].dyn_cast<StringAttr>(); 74 if (!strAttr) 75 return emitError() 76 << "expected string value in dimension level types"; 77 auto strVal = strAttr.getValue(); 78 if (strVal != getDenseDimLevelTypeVal() && 79 strVal != getCompressedDimLevelTypeVal() && 80 strVal != getSingletonDimLevelTypeVal()) 81 return emitError() << "unexpected dimension level type: " << strAttr; 82 } 83 } else if (attr.first == getSparseDimOrderingAttrName()) { 84 // Dimension order verification. 85 auto affineAttr = attr.second.dyn_cast<AffineMapAttr>(); 86 if (!affineAttr) 87 return emitError() << "expected an affine map for dimension ordering"; 88 AffineMap map = affineAttr.getValue(); 89 if (size != map.getNumResults() || !map.isPermutation()) 90 return emitError() << "expected a permutation affine map of size " 91 << size << " for dimension ordering"; 92 } else if (attr.first == getSparsePointerBitWidthAttrName() || 93 attr.first == getSparseIndexBitWidthAttrName()) { 94 // Pointer or index bitwidth verification. 95 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 96 if (!intAttr) 97 return emitError() << "expected an integral bitwidth"; 98 switch (intAttr.getInt()) { 99 case 0: 100 case 8: 101 case 16: 102 case 32: 103 case 64: 104 continue; 105 default: 106 return emitError() << "unexpected bitwidth: " << intAttr.getInt(); 107 } 108 } else { 109 return emitError() << "unexpected key: " << attr.first.str(); 110 } 111 } 112 return success(); 113 } 114 115 SparseTensorEncodingAttr::DimLevelType 116 SparseTensorEncodingAttr::getDimLevelType(unsigned dim) const { 117 if (auto value = getDict().get(getSparseDimLevelTypeAttrName())) { 118 auto strVal = 119 value.dyn_cast<ArrayAttr>()[dim].cast<StringAttr>().getValue(); 120 if (strVal == getCompressedDimLevelTypeVal()) 121 return DimLevelType::Compressed; 122 if (strVal == getSingletonDimLevelTypeVal()) 123 return DimLevelType::Singleton; 124 } 125 return DimLevelType::Dense; 126 } 127 128 AffineMap SparseTensorEncodingAttr::getDimOrdering() const { 129 if (auto value = getDict().get(getSparseDimOrderingAttrName())) 130 return value.cast<AffineMapAttr>().getValue(); 131 return {}; 132 } 133 134 unsigned SparseTensorEncodingAttr::getPointerBitWidth() const { 135 if (auto value = getDict().get(getSparsePointerBitWidthAttrName())) 136 return value.cast<IntegerAttr>().getInt(); 137 return 0; 138 } 139 140 unsigned SparseTensorEncodingAttr::getIndexBitWidth() const { 141 if (auto value = getDict().get(getSparseIndexBitWidthAttrName())) 142 return value.cast<IntegerAttr>().getInt(); 143 return 0; 144 } 145 146 //===----------------------------------------------------------------------===// 147 // TensorDialect Dialect Interfaces 148 //===----------------------------------------------------------------------===// 149 150 namespace { 151 struct TensorInlinerInterface : public DialectInlinerInterface { 152 using DialectInlinerInterface::DialectInlinerInterface; 153 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, 154 BlockAndValueMapping &valueMapping) const final { 155 return true; 156 } 157 bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, 158 BlockAndValueMapping &) const final { 159 return true; 160 } 161 }; 162 } // end anonymous namespace 163 164 //===----------------------------------------------------------------------===// 165 // TensorDialect Methods 166 //===----------------------------------------------------------------------===// 167 168 void TensorDialect::initialize() { 169 addAttributes< 170 #define GET_ATTRDEF_LIST 171 #include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc" 172 >(); 173 addOperations< 174 #define GET_OP_LIST 175 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc" 176 >(); 177 addInterfaces<TensorInlinerInterface>(); 178 } 179 180 Attribute TensorDialect::parseAttribute(DialectAsmParser &parser, 181 Type type) const { 182 StringRef attrTag; 183 if (failed(parser.parseKeyword(&attrTag))) 184 return Attribute(); 185 Attribute attr; 186 auto parseResult = 187 generatedAttributeParser(getContext(), parser, attrTag, type, attr); 188 if (parseResult.hasValue()) 189 return attr; 190 parser.emitError(parser.getNameLoc(), "unknown tensor attribute"); 191 return Attribute(); 192 } 193 194 void TensorDialect::printAttribute(::mlir::Attribute attr, 195 ::mlir::DialectAsmPrinter &printer) const { 196 if (succeeded(generatedAttributePrinter(attr, printer))) 197 return; 198 } 199