//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"

// Returns true iff the given bitwidth is one of the supported storage
// widths; zero is accepted as well (it selects the native `index` width,
// as can be seen in isMatchingWidth below).
static bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

// Parses a sparse tensor encoding attribute of the form
//   <{ dimLevelType = [...], dimOrdering = ...,
//      pointerBitWidth = ..., indexBitWidth = ... }>
// where every dictionary entry is optional and unknown keys are rejected.
// Emits a parser error and returns a null attribute on malformed input.
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};
  unsigned ptr = 0; // default: native index bitwidth
  unsigned ind = 0; // default: native index bitwidth
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      // An array of strings, one per dimension, naming its level type.
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      // An affine map that gives the storage order of the dimensions.
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      ind = intAttr.getInt();
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}

// Prints the encoding back in the same dictionary syntax accepted by parse().
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  // Print the struct-like storage in dictionary fashion.
  printer << "<{ dimLevelType = [ ";
  for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
    switch (getDimLevelType()[i]) {
    case DimLevelType::Dense:
      printer << "\"dense\"";
      break;
    case DimLevelType::Compressed:
      printer << "\"compressed\"";
      break;
    case DimLevelType::Singleton:
      printer << "\"singleton\"";
      break;
    }
    if (i != e - 1)
      printer << ", ";
  }
  printer << " ]";
  // The dimension ordering is optional and elided when absent.
  if (getDimOrdering())
    printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
  printer << ", pointerBitWidth = " << getPointerBitWidth()
          << ", indexBitWidth = " << getIndexBitWidth() << " }>";
}

// Verifies internal consistency of the encoding fields themselves:
// supported bitwidths, and a dimension ordering (when present) that is a
// permutation whose result count matches the number of level types.
LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
    unsigned pointerBitWidth, unsigned indexBitWidth) {
  if (!acceptBitWidth(pointerBitWidth))
    return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
  if (!acceptBitWidth(indexBitWidth))
    return emitError() << "unexpected index bitwidth: " << indexBitWidth;
  if (dimOrdering) {
    if (!dimOrdering.isPermutation())
      return emitError()
             << "expected a permutation affine map for dimension ordering";
    if (dimOrdering.getNumResults() != dimLevelType.size())
      return emitError() << "unexpected mismatch in ordering and dimension "
                            "level types size";
  }
  return success();
}

// Verifies the encoding against the tensor type it annotates: the
// per-dimension fields must cover the tensor's full rank.
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<int64_t> shape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.
  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
                    getPointerBitWidth(), getIndexBitWidth())))
    return failure();
  // Check integrity with tensor type specifics. Dimension ordering is optional,
  // but we always should have dimension level types for the full rank.
  unsigned size = shape.size();
  if (size == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimOrdering() && getDimOrdering().getNumResults() != size)
    return emitError() << "expected an affine map of size " << size
                       << " for dimension ordering";
  if (getDimLevelType().size() != size)
    return emitError() << "expected an array of size " << size
                       << " for dimension level types";
  return success();
}

// Convenience accessor: returns the sparse encoding of a ranked tensor
// type, or null when the type is not a ranked tensor or carries no
// sparse encoding attribute.
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = type.dyn_cast<RankedTensorType>())
    return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// TensorDialect Operations.
192 //===----------------------------------------------------------------------===// 193 194 static LogicalResult isInBounds(Value dim, Value tensor) { 195 IntegerAttr constantAttr; 196 if (matchPattern(dim, m_Constant(&constantAttr))) { 197 unsigned d = constantAttr.getInt(); 198 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 199 return failure(); 200 } 201 return success(); // in bounds, or symbolic 202 } 203 204 static LogicalResult isMatchingWidth(Value result, unsigned width) { 205 Type etp = result.getType().cast<MemRefType>().getElementType(); 206 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 207 return success(); 208 return failure(); 209 } 210 211 LogicalResult NewOp::verify() { 212 if (!getSparseTensorEncoding(result().getType())) 213 return emitError("expected a sparse tensor result"); 214 return success(); 215 } 216 217 LogicalResult ConvertOp::verify() { 218 if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) { 219 if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) { 220 if (tp1.getRank() != tp2.getRank()) 221 return emitError("unexpected conversion mismatch in rank"); 222 auto shape1 = tp1.getShape(); 223 auto shape2 = tp2.getShape(); 224 // Accept size matches between the source and the destination type 225 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 226 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 
227 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) 228 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) 229 return emitError("unexpected conversion mismatch in dimension ") << d; 230 return success(); 231 } 232 } 233 return emitError("unexpected type in convert"); 234 } 235 236 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 237 if (getType() == source().getType()) 238 return source(); 239 return {}; 240 } 241 242 LogicalResult ToPointersOp::verify() { 243 if (auto e = getSparseTensorEncoding(tensor().getType())) { 244 if (failed(isInBounds(dim(), tensor()))) 245 return emitError("requested pointers dimension out of bounds"); 246 if (failed(isMatchingWidth(result(), e.getPointerBitWidth()))) 247 return emitError("unexpected type for pointers"); 248 return success(); 249 } 250 return emitError("expected a sparse tensor to get pointers"); 251 } 252 253 LogicalResult ToIndicesOp::verify() { 254 if (auto e = getSparseTensorEncoding(tensor().getType())) { 255 if (failed(isInBounds(dim(), tensor()))) 256 return emitError("requested indices dimension out of bounds"); 257 if (failed(isMatchingWidth(result(), e.getIndexBitWidth()))) 258 return emitError("unexpected type for indices"); 259 return success(); 260 } 261 return emitError("expected a sparse tensor to get indices"); 262 } 263 264 LogicalResult ToValuesOp::verify() { 265 if (!getSparseTensorEncoding(tensor().getType())) 266 return emitError("expected a sparse tensor to get values"); 267 RankedTensorType ttp = tensor().getType().cast<RankedTensorType>(); 268 MemRefType mtp = result().getType().cast<MemRefType>(); 269 if (ttp.getElementType() != mtp.getElementType()) 270 return emitError("unexpected mismatch in element types"); 271 return success(); 272 } 273 274 //===----------------------------------------------------------------------===// 275 // TensorDialect Management Operations. 
276 //===----------------------------------------------------------------------===// 277 278 LogicalResult LexInsertOp::verify() { 279 if (!getSparseTensorEncoding(tensor().getType())) 280 return emitError("expected a sparse tensor for insertion"); 281 return success(); 282 } 283 284 LogicalResult ExpandOp::verify() { 285 if (!getSparseTensorEncoding(tensor().getType())) 286 return emitError("expected a sparse tensor for expansion"); 287 return success(); 288 } 289 290 LogicalResult CompressOp::verify() { 291 if (!getSparseTensorEncoding(tensor().getType())) 292 return emitError("expected a sparse tensor for compression"); 293 return success(); 294 } 295 296 LogicalResult LoadOp::verify() { 297 if (!getSparseTensorEncoding(tensor().getType())) 298 return emitError("expected a sparse tensor to materialize"); 299 return success(); 300 } 301 302 LogicalResult ReleaseOp::verify() { 303 if (!getSparseTensorEncoding(tensor().getType())) 304 return emitError("expected a sparse tensor to release"); 305 return success(); 306 } 307 308 LogicalResult OutOp::verify() { 309 if (!getSparseTensorEncoding(tensor().getType())) 310 return emitError("expected a sparse tensor for output"); 311 return success(); 312 } 313 314 //===----------------------------------------------------------------------===// 315 // TensorDialect Linalg.Generic Operations. 
316 //===----------------------------------------------------------------------===// 317 318 template <class T> 319 static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, 320 const char *regionName, 321 TypeRange inputTypes, Type outputType) { 322 unsigned numArgs = region.getNumArguments(); 323 unsigned expectedNum = inputTypes.size(); 324 if (numArgs != expectedNum) 325 return op->emitError() << regionName << " region must have exactly " 326 << expectedNum << " arguments"; 327 328 for (unsigned i = 0; i < numArgs; i++) { 329 Type typ = region.getArgument(i).getType(); 330 if (typ != inputTypes[i]) 331 return op->emitError() << regionName << " region argument " << (i + 1) 332 << " type mismatch"; 333 } 334 Operation *term = region.front().getTerminator(); 335 YieldOp yield = dyn_cast<YieldOp>(term); 336 if (!yield) 337 return op->emitError() << regionName 338 << " region must end with sparse_tensor.yield"; 339 if (yield.getOperand().getType() != outputType) 340 return op->emitError() << regionName << " region yield type mismatch"; 341 342 return success(); 343 } 344 345 LogicalResult BinaryOp::verify() { 346 NamedAttrList attrs = (*this)->getAttrs(); 347 Type leftType = x().getType(); 348 Type rightType = y().getType(); 349 Type outputType = output().getType(); 350 Region &overlap = overlapRegion(); 351 Region &left = leftRegion(); 352 Region &right = rightRegion(); 353 354 // Check correct number of block arguments and return type for each 355 // non-empty region. 
356 LogicalResult regionResult = success(); 357 if (!overlap.empty()) { 358 regionResult = verifyNumBlockArgs( 359 this, overlap, "overlap", TypeRange{leftType, rightType}, outputType); 360 if (failed(regionResult)) 361 return regionResult; 362 } 363 if (!left.empty()) { 364 regionResult = 365 verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType); 366 if (failed(regionResult)) 367 return regionResult; 368 } else if (left_identity()) { 369 if (leftType != outputType) 370 return emitError("left=identity requires first argument to have the same " 371 "type as the output"); 372 } 373 if (!right.empty()) { 374 regionResult = verifyNumBlockArgs(this, right, "right", 375 TypeRange{rightType}, outputType); 376 if (failed(regionResult)) 377 return regionResult; 378 } else if (right_identity()) { 379 if (rightType != outputType) 380 return emitError("right=identity requires second argument to have the " 381 "same type as the output"); 382 } 383 384 return success(); 385 } 386 387 LogicalResult UnaryOp::verify() { 388 Type inputType = x().getType(); 389 Type outputType = output().getType(); 390 LogicalResult regionResult = success(); 391 392 // Check correct number of block arguments and return type for each 393 // non-empty region. 394 Region &present = presentRegion(); 395 if (!present.empty()) { 396 regionResult = verifyNumBlockArgs(this, present, "present", 397 TypeRange{inputType}, outputType); 398 if (failed(regionResult)) 399 return regionResult; 400 } 401 Region &absent = absentRegion(); 402 if (!absent.empty()) { 403 regionResult = 404 verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType); 405 if (failed(regionResult)) 406 return regionResult; 407 } 408 409 return success(); 410 } 411 412 LogicalResult YieldOp::verify() { 413 // Check for compatible parent. 
414 auto *parentOp = (*this)->getParentOp(); 415 if (auto binaryOp = dyn_cast<BinaryOp>(parentOp)) 416 return success(); 417 if (auto unaryOp = dyn_cast<UnaryOp>(parentOp)) 418 return success(); 419 420 return emitOpError("expected parent op to be sparse_tensor binary or unary"); 421 } 422 423 //===----------------------------------------------------------------------===// 424 // TensorDialect Methods. 425 //===----------------------------------------------------------------------===// 426 427 void SparseTensorDialect::initialize() { 428 addAttributes< 429 #define GET_ATTRDEF_LIST 430 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 431 >(); 432 addOperations< 433 #define GET_OP_LIST 434 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 435 >(); 436 } 437 438 #define GET_OP_CLASSES 439 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 440 441 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 442