1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "llvm/ADT/TypeSwitch.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 //===----------------------------------------------------------------------===// 20 // TensorDialect Attribute Methods. 21 //===----------------------------------------------------------------------===// 22 23 #define GET_ATTRDEF_CLASSES 24 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 25 26 static bool acceptBitWidth(unsigned bitWidth) { 27 switch (bitWidth) { 28 case 0: 29 case 8: 30 case 16: 31 case 32: 32 case 64: 33 return true; 34 default: 35 return false; 36 } 37 } 38 39 Attribute SparseTensorEncodingAttr::parse(MLIRContext *context, 40 DialectAsmParser &parser, Type type) { 41 if (failed(parser.parseLess())) 42 return {}; 43 // Parse the data as a dictionary. 44 DictionaryAttr dict; 45 if (failed(parser.parseAttribute(dict))) 46 return {}; 47 if (failed(parser.parseGreater())) 48 return {}; 49 // Process the data from the parsed dictionary value into struct-like data. 50 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 51 AffineMap map = {}; 52 unsigned ptr = 0; 53 unsigned ind = 0; 54 for (const NamedAttribute &attr : dict) { 55 if (attr.first == "dimLevelType") { 56 auto arrayAttr = attr.second.dyn_cast<ArrayAttr>(); 57 if (!arrayAttr) { 58 parser.emitError(parser.getNameLoc(), 59 "expected an array for dimension level types"); 60 return {}; 61 } 62 for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) { 63 auto strAttr = arrayAttr[i].dyn_cast<StringAttr>(); 64 if (!strAttr) { 65 parser.emitError(parser.getNameLoc(), 66 "expected a string value in dimension level types"); 67 return {}; 68 } 69 auto strVal = strAttr.getValue(); 70 if (strVal == "dense") { 71 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 72 } else if (strVal == "compressed") { 73 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 74 } else if (strVal == "singleton") { 75 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 76 } else { 77 parser.emitError(parser.getNameLoc(), 78 "unexpected dimension level type: ") 79 << strVal; 80 return {}; 81 } 82 } 83 } else if (attr.first == "dimOrdering") { 84 auto affineAttr = attr.second.dyn_cast<AffineMapAttr>(); 85 if (!affineAttr) { 86 parser.emitError(parser.getNameLoc(), 87 "expected an affine map for dimension ordering"); 88 return {}; 89 } 90 map = affineAttr.getValue(); 91 } else if (attr.first == "pointerBitWidth") { 92 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 93 if (!intAttr) { 94 parser.emitError(parser.getNameLoc(), 95 "expected an integral pointer bitwidth"); 96 return {}; 97 } 98 ptr = intAttr.getInt(); 99 } else if (attr.first == "indexBitWidth") { 100 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 101 if (!intAttr) { 102 parser.emitError(parser.getNameLoc(), 103 "expected an integral index bitwidth"); 104 return {}; 105 } 106 ind = intAttr.getInt(); 107 } else { 108 parser.emitError(parser.getNameLoc(), "unexpected key: ") 109 << attr.first.str(); 110 return {}; 111 } 112 } 113 // Construct struct-like storage for attribute. 114 return parser.getChecked<SparseTensorEncodingAttr>(context, dlt, map, ptr, 115 ind); 116 } 117 118 void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const { 119 // Print the struct-like storage in dictionary fashion. 120 printer << "encoding<{ dimLevelType = [ "; 121 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 122 switch (getDimLevelType()[i]) { 123 case DimLevelType::Dense: 124 printer << "\"dense\""; 125 break; 126 case DimLevelType::Compressed: 127 printer << "\"compressed\""; 128 break; 129 case DimLevelType::Singleton: 130 printer << "\"singleton\""; 131 break; 132 } 133 if (i != e - 1) 134 printer << ", "; 135 } 136 printer << " ]"; 137 if (getDimOrdering()) 138 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 139 printer << ", pointerBitWidth = " << getPointerBitWidth() 140 << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 141 } 142 143 LogicalResult SparseTensorEncodingAttr::verify( 144 function_ref<InFlightDiagnostic()> emitError, 145 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 146 unsigned pointerBitWidth, unsigned indexBitWidth) { 147 if (!acceptBitWidth(pointerBitWidth)) 148 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 149 if (!acceptBitWidth(indexBitWidth)) 150 return emitError() << "unexpected index bitwidth: " << indexBitWidth; 151 if (dimOrdering) { 152 if (!dimOrdering.isPermutation()) 153 return emitError() 154 << "expected a permutation affine map for dimension ordering"; 155 if (dimOrdering.getNumResults() != dimLevelType.size()) 156 return emitError() << "unexpected mismatch in ordering and dimension " 157 "level types size"; 158 } 159 return success(); 160 } 161 162 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 163 ArrayRef<int64_t> shape, Type elementType, 164 function_ref<InFlightDiagnostic()> emitError) const { 165 // Check structural integrity. 166 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 167 getPointerBitWidth(), getIndexBitWidth()))) 168 return failure(); 169 // Check integrity with tensor type specifics. Dimension ordering is optional, 170 // but we always should have dimension level types for the full rank. 171 unsigned size = shape.size(); 172 if (getDimOrdering() && getDimOrdering().getNumResults() != size) 173 return emitError() << "expected an affine map of size " << size 174 << " for dimension ordering"; 175 if (getDimLevelType().size() != size) 176 return emitError() << "expected an array of size " << size 177 << " for dimension level types"; 178 return success(); 179 } 180 181 SparseTensorEncodingAttr 182 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 183 if (auto ttp = type.dyn_cast<RankedTensorType>()) 184 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 185 return nullptr; 186 } 187 188 //===----------------------------------------------------------------------===// 189 // TensorDialect Operations. 190 //===----------------------------------------------------------------------===// 191 192 static LogicalResult isInBounds(Value dim, Value tensor) { 193 if (auto constantOp = dim.getDefiningOp<ConstantOp>()) { 194 unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt(); 195 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 196 return failure(); 197 } 198 return success(); // in bounds, or symbolic 199 } 200 201 static LogicalResult isMatchingWidth(Value result, unsigned width) { 202 Type etp = result.getType().cast<MemRefType>().getElementType(); 203 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 204 return success(); 205 return failure(); 206 } 207 208 static LogicalResult verify(NewOp op) { 209 if (!getSparseTensorEncoding(op.getResult().getType())) 210 return op.emitError("expected a sparse tensor result"); 211 return success(); 212 } 213 214 static LogicalResult verify(ToPointersOp op) { 215 if (failed(isInBounds(op.dim(), op.tensor()))) 216 return op.emitError("requested pointers dimension out of bounds"); 217 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 218 if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) 219 return op.emitError("unexpected type for pointers"); 220 return success(); 221 } 222 return op.emitError("expected a sparse tensor to get pointers"); 223 } 224 225 static LogicalResult verify(ToIndicesOp op) { 226 if (failed(isInBounds(op.dim(), op.tensor()))) 227 return op.emitError("requested indices dimension out of bounds"); 228 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 229 if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) 230 return op.emitError("unexpected type for indices"); 231 return success(); 232 } 233 return op.emitError("expected a sparse tensor to get indices"); 234 } 235 236 static LogicalResult verify(ToValuesOp op) { 237 if (!getSparseTensorEncoding(op.tensor().getType())) 238 return op.emitError("expected a sparse tensor to get values"); 239 RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>(); 240 MemRefType mtp = op.result().getType().cast<MemRefType>(); 241 if (ttp.getElementType() != mtp.getElementType()) 242 return op.emitError("unexpected mismatch in element types"); 243 return success(); 244 } 245 246 //===----------------------------------------------------------------------===// 247 // TensorDialect Methods. 248 //===----------------------------------------------------------------------===// 249 250 void SparseTensorDialect::initialize() { 251 addAttributes< 252 #define GET_ATTRDEF_LIST 253 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 254 >(); 255 addOperations< 256 #define GET_OP_LIST 257 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 258 >(); 259 } 260 261 #define GET_OP_CLASSES 262 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 263 264 Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser, 265 Type type) const { 266 StringRef attrTag; 267 if (failed(parser.parseKeyword(&attrTag))) 268 return Attribute(); 269 Attribute attr; 270 auto parseResult = 271 generatedAttributeParser(getContext(), parser, attrTag, type, attr); 272 if (parseResult.hasValue()) 273 return attr; 274 parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute"); 275 return Attribute(); 276 } 277 278 void SparseTensorDialect::printAttribute(Attribute attr, 279 DialectAsmPrinter &printer) const { 280 if (succeeded(generatedAttributePrinter(attr, printer))) 281 return; 282 } 283