1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "llvm/ADT/TypeSwitch.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 20 21 //===----------------------------------------------------------------------===// 22 // TensorDialect Attribute Methods. 23 //===----------------------------------------------------------------------===// 24 25 #define GET_ATTRDEF_CLASSES 26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 27 28 static bool acceptBitWidth(unsigned bitWidth) { 29 switch (bitWidth) { 30 case 0: 31 case 8: 32 case 16: 33 case 32: 34 case 64: 35 return true; 36 default: 37 return false; 38 } 39 } 40 41 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 42 if (failed(parser.parseLess())) 43 return {}; 44 // Parse the data as a dictionary. 45 DictionaryAttr dict; 46 if (failed(parser.parseAttribute(dict))) 47 return {}; 48 if (failed(parser.parseGreater())) 49 return {}; 50 // Process the data from the parsed dictionary value into struct-like data. 51 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 52 AffineMap map = {}; 53 unsigned ptr = 0; 54 unsigned ind = 0; 55 for (const NamedAttribute &attr : dict) { 56 if (attr.getName() == "dimLevelType") { 57 auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); 58 if (!arrayAttr) { 59 parser.emitError(parser.getNameLoc(), 60 "expected an array for dimension level types"); 61 return {}; 62 } 63 for (auto i : arrayAttr) { 64 auto strAttr = i.dyn_cast<StringAttr>(); 65 if (!strAttr) { 66 parser.emitError(parser.getNameLoc(), 67 "expected a string value in dimension level types"); 68 return {}; 69 } 70 auto strVal = strAttr.getValue(); 71 if (strVal == "dense") { 72 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 73 } else if (strVal == "compressed") { 74 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 75 } else if (strVal == "singleton") { 76 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 77 } else { 78 parser.emitError(parser.getNameLoc(), 79 "unexpected dimension level type: ") 80 << strVal; 81 return {}; 82 } 83 } 84 } else if (attr.getName() == "dimOrdering") { 85 auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>(); 86 if (!affineAttr) { 87 parser.emitError(parser.getNameLoc(), 88 "expected an affine map for dimension ordering"); 89 return {}; 90 } 91 map = affineAttr.getValue(); 92 } else if (attr.getName() == "pointerBitWidth") { 93 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 94 if (!intAttr) { 95 parser.emitError(parser.getNameLoc(), 96 "expected an integral pointer bitwidth"); 97 return {}; 98 } 99 ptr = intAttr.getInt(); 100 } else if (attr.getName() == "indexBitWidth") { 101 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 102 if (!intAttr) { 103 parser.emitError(parser.getNameLoc(), 104 "expected an integral index bitwidth"); 105 return {}; 106 } 107 ind = intAttr.getInt(); 108 } else { 109 parser.emitError(parser.getNameLoc(), "unexpected key: ") 110 << attr.getName().strref(); 111 return {}; 112 } 113 } 114 // Construct struct-like storage for attribute. 115 return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, 116 map, ptr, ind); 117 } 118 119 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 120 // Print the struct-like storage in dictionary fashion. 121 printer << "<{ dimLevelType = [ "; 122 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 123 switch (getDimLevelType()[i]) { 124 case DimLevelType::Dense: 125 printer << "\"dense\""; 126 break; 127 case DimLevelType::Compressed: 128 printer << "\"compressed\""; 129 break; 130 case DimLevelType::Singleton: 131 printer << "\"singleton\""; 132 break; 133 } 134 if (i != e - 1) 135 printer << ", "; 136 } 137 printer << " ]"; 138 if (getDimOrdering()) 139 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 140 printer << ", pointerBitWidth = " << getPointerBitWidth() 141 << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 142 } 143 144 LogicalResult SparseTensorEncodingAttr::verify( 145 function_ref<InFlightDiagnostic()> emitError, 146 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 147 unsigned pointerBitWidth, unsigned indexBitWidth) { 148 if (!acceptBitWidth(pointerBitWidth)) 149 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 150 if (!acceptBitWidth(indexBitWidth)) 151 return emitError() << "unexpected index bitwidth: " << indexBitWidth; 152 if (dimOrdering) { 153 if (!dimOrdering.isPermutation()) 154 return emitError() 155 << "expected a permutation affine map for dimension ordering"; 156 if (dimOrdering.getNumResults() != dimLevelType.size()) 157 return emitError() << "unexpected mismatch in ordering and dimension " 158 "level types size"; 159 } 160 return success(); 161 } 162 163 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 164 ArrayRef<int64_t> shape, Type elementType, 165 function_ref<InFlightDiagnostic()> emitError) const { 166 // Check structural integrity. 167 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 168 getPointerBitWidth(), getIndexBitWidth()))) 169 return failure(); 170 // Check integrity with tensor type specifics. Dimension ordering is optional, 171 // but we always should have dimension level types for the full rank. 172 unsigned size = shape.size(); 173 if (size == 0) 174 return emitError() << "expected non-scalar sparse tensor"; 175 if (getDimOrdering() && getDimOrdering().getNumResults() != size) 176 return emitError() << "expected an affine map of size " << size 177 << " for dimension ordering"; 178 if (getDimLevelType().size() != size) 179 return emitError() << "expected an array of size " << size 180 << " for dimension level types"; 181 return success(); 182 } 183 184 SparseTensorEncodingAttr 185 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 186 if (auto ttp = type.dyn_cast<RankedTensorType>()) 187 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 188 return nullptr; 189 } 190 191 //===----------------------------------------------------------------------===// 192 // TensorDialect Operations. 193 //===----------------------------------------------------------------------===// 194 195 static LogicalResult isInBounds(Value dim, Value tensor) { 196 IntegerAttr constantAttr; 197 if (matchPattern(dim, m_Constant(&constantAttr))) { 198 unsigned d = constantAttr.getInt(); 199 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 200 return failure(); 201 } 202 return success(); // in bounds, or symbolic 203 } 204 205 static LogicalResult isMatchingWidth(Value result, unsigned width) { 206 Type etp = result.getType().cast<MemRefType>().getElementType(); 207 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 208 return success(); 209 return failure(); 210 } 211 212 LogicalResult NewOp::verify() { 213 if (!getSparseTensorEncoding(result().getType())) 214 return emitError("expected a sparse tensor result"); 215 return success(); 216 } 217 218 LogicalResult InitOp::verify() { 219 if (!getSparseTensorEncoding(result().getType())) 220 return emitError("expected a sparse tensor result"); 221 RankedTensorType ttp = getType().cast<RankedTensorType>(); 222 unsigned rank = ttp.getRank(); 223 if (rank != sizes().size()) 224 return emitError("unexpected mismatch between tensor rank and sizes: ") 225 << rank << " vs. " << sizes().size(); 226 auto shape = ttp.getShape(); 227 for (unsigned i = 0; i < rank; i++) { 228 if (shape[i] == ShapedType::kDynamicSize) 229 continue; 230 IntegerAttr constantAttr; 231 if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) || 232 constantAttr.getInt() != shape[i]) { 233 return emitError("unexpected mismatch with static dimension size ") 234 << shape[i]; 235 } 236 } 237 return success(); 238 } 239 240 LogicalResult ConvertOp::verify() { 241 if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) { 242 if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) { 243 if (tp1.getRank() != tp2.getRank()) 244 return emitError("unexpected conversion mismatch in rank"); 245 auto shape1 = tp1.getShape(); 246 auto shape2 = tp2.getShape(); 247 // Accept size matches between the source and the destination type 248 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 249 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 250 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) 251 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) 252 return emitError("unexpected conversion mismatch in dimension ") << d; 253 return success(); 254 } 255 } 256 return emitError("unexpected type in convert"); 257 } 258 259 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 260 if (getType() == source().getType()) 261 return source(); 262 return {}; 263 } 264 265 LogicalResult ToPointersOp::verify() { 266 if (auto e = getSparseTensorEncoding(tensor().getType())) { 267 if (failed(isInBounds(dim(), tensor()))) 268 return emitError("requested pointers dimension out of bounds"); 269 if (failed(isMatchingWidth(result(), e.getPointerBitWidth()))) 270 return emitError("unexpected type for pointers"); 271 return success(); 272 } 273 return emitError("expected a sparse tensor to get pointers"); 274 } 275 276 LogicalResult ToIndicesOp::verify() { 277 if (auto e = getSparseTensorEncoding(tensor().getType())) { 278 if (failed(isInBounds(dim(), tensor()))) 279 return emitError("requested indices dimension out of bounds"); 280 if (failed(isMatchingWidth(result(), e.getIndexBitWidth()))) 281 return emitError("unexpected type for indices"); 282 return success(); 283 } 284 return emitError("expected a sparse tensor to get indices"); 285 } 286 287 LogicalResult ToValuesOp::verify() { 288 if (!getSparseTensorEncoding(tensor().getType())) 289 return emitError("expected a sparse tensor to get values"); 290 RankedTensorType ttp = tensor().getType().cast<RankedTensorType>(); 291 MemRefType mtp = result().getType().cast<MemRefType>(); 292 if (ttp.getElementType() != mtp.getElementType()) 293 return emitError("unexpected mismatch in element types"); 294 return success(); 295 } 296 297 //===----------------------------------------------------------------------===// 298 // TensorDialect Management Operations. 299 //===----------------------------------------------------------------------===// 300 301 LogicalResult LexInsertOp::verify() { 302 if (!getSparseTensorEncoding(tensor().getType())) 303 return emitError("expected a sparse tensor for insertion"); 304 return success(); 305 } 306 307 LogicalResult ExpandOp::verify() { 308 if (!getSparseTensorEncoding(tensor().getType())) 309 return emitError("expected a sparse tensor for expansion"); 310 return success(); 311 } 312 313 LogicalResult CompressOp::verify() { 314 if (!getSparseTensorEncoding(tensor().getType())) 315 return emitError("expected a sparse tensor for compression"); 316 return success(); 317 } 318 319 LogicalResult LoadOp::verify() { 320 if (!getSparseTensorEncoding(tensor().getType())) 321 return emitError("expected a sparse tensor to materialize"); 322 return success(); 323 } 324 325 LogicalResult ReleaseOp::verify() { 326 if (!getSparseTensorEncoding(tensor().getType())) 327 return emitError("expected a sparse tensor to release"); 328 return success(); 329 } 330 331 LogicalResult OutOp::verify() { 332 if (!getSparseTensorEncoding(tensor().getType())) 333 return emitError("expected a sparse tensor for output"); 334 return success(); 335 } 336 337 //===----------------------------------------------------------------------===// 338 // TensorDialect Methods. 339 //===----------------------------------------------------------------------===// 340 341 void SparseTensorDialect::initialize() { 342 addAttributes< 343 #define GET_ATTRDEF_LIST 344 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 345 >(); 346 addOperations< 347 #define GET_OP_LIST 348 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 349 >(); 350 } 351 352 #define GET_OP_CLASSES 353 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 354