1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 10 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 17 using namespace mlir; 18 using namespace mlir::sparse_tensor; 19 20 //===----------------------------------------------------------------------===// 21 // TensorDialect Attribute Methods. 22 //===----------------------------------------------------------------------===// 23 24 #define GET_ATTRDEF_CLASSES 25 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 26 27 static bool acceptBitWidth(unsigned bitWidth) { 28 switch (bitWidth) { 29 case 0: 30 case 8: 31 case 16: 32 case 32: 33 case 64: 34 return true; 35 default: 36 return false; 37 } 38 } 39 40 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 41 if (failed(parser.parseLess())) 42 return {}; 43 // Parse the data as a dictionary. 44 DictionaryAttr dict; 45 if (failed(parser.parseAttribute(dict))) 46 return {}; 47 if (failed(parser.parseGreater())) 48 return {}; 49 // Process the data from the parsed dictionary value into struct-like data. 50 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 51 AffineMap map = {}; 52 unsigned ptr = 0; 53 unsigned ind = 0; 54 for (const NamedAttribute &attr : dict) { 55 if (attr.getName() == "dimLevelType") { 56 auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); 57 if (!arrayAttr) { 58 parser.emitError(parser.getNameLoc(), 59 "expected an array for dimension level types"); 60 return {}; 61 } 62 for (auto i : arrayAttr) { 63 auto strAttr = i.dyn_cast<StringAttr>(); 64 if (!strAttr) { 65 parser.emitError(parser.getNameLoc(), 66 "expected a string value in dimension level types"); 67 return {}; 68 } 69 auto strVal = strAttr.getValue(); 70 if (strVal == "dense") { 71 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 72 } else if (strVal == "compressed") { 73 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 74 } else if (strVal == "singleton") { 75 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 76 } else { 77 parser.emitError(parser.getNameLoc(), 78 "unexpected dimension level type: ") 79 << strVal; 80 return {}; 81 } 82 } 83 } else if (attr.getName() == "dimOrdering") { 84 auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>(); 85 if (!affineAttr) { 86 parser.emitError(parser.getNameLoc(), 87 "expected an affine map for dimension ordering"); 88 return {}; 89 } 90 map = affineAttr.getValue(); 91 } else if (attr.getName() == "pointerBitWidth") { 92 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 93 if (!intAttr) { 94 parser.emitError(parser.getNameLoc(), 95 "expected an integral pointer bitwidth"); 96 return {}; 97 } 98 ptr = intAttr.getInt(); 99 } else if (attr.getName() == "indexBitWidth") { 100 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 101 if (!intAttr) { 102 parser.emitError(parser.getNameLoc(), 103 "expected an integral index bitwidth"); 104 return {}; 105 } 106 ind = intAttr.getInt(); 107 } else { 108 parser.emitError(parser.getNameLoc(), "unexpected key: ") 109 << attr.getName().strref(); 110 return {}; 111 } 112 } 113 // Construct struct-like storage for attribute. 114 return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, 115 map, ptr, ind); 116 } 117 118 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 119 // Print the struct-like storage in dictionary fashion. 120 printer << "<{ dimLevelType = [ "; 121 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 122 switch (getDimLevelType()[i]) { 123 case DimLevelType::Dense: 124 printer << "\"dense\""; 125 break; 126 case DimLevelType::Compressed: 127 printer << "\"compressed\""; 128 break; 129 case DimLevelType::Singleton: 130 printer << "\"singleton\""; 131 break; 132 } 133 if (i != e - 1) 134 printer << ", "; 135 } 136 printer << " ]"; 137 if (getDimOrdering()) 138 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 139 printer << ", pointerBitWidth = " << getPointerBitWidth() 140 << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 141 } 142 143 LogicalResult SparseTensorEncodingAttr::verify( 144 function_ref<InFlightDiagnostic()> emitError, 145 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 146 unsigned pointerBitWidth, unsigned indexBitWidth) { 147 if (!acceptBitWidth(pointerBitWidth)) 148 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 149 if (!acceptBitWidth(indexBitWidth)) 150 return emitError() << "unexpected index bitwidth: " << indexBitWidth; 151 if (dimOrdering) { 152 if (!dimOrdering.isPermutation()) 153 return emitError() 154 << "expected a permutation affine map for dimension ordering"; 155 if (dimOrdering.getNumResults() != dimLevelType.size()) 156 return emitError() << "unexpected mismatch in ordering and dimension " 157 "level types size"; 158 } 159 return success(); 160 } 161 162 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 163 ArrayRef<int64_t> shape, Type elementType, 164 function_ref<InFlightDiagnostic()> emitError) const { 165 // Check structural integrity. 166 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 167 getPointerBitWidth(), getIndexBitWidth()))) 168 return failure(); 169 // Check integrity with tensor type specifics. Dimension ordering is optional, 170 // but we always should have dimension level types for the full rank. 171 unsigned size = shape.size(); 172 if (size == 0) 173 return emitError() << "expected non-scalar sparse tensor"; 174 if (getDimOrdering() && getDimOrdering().getNumResults() != size) 175 return emitError() << "expected an affine map of size " << size 176 << " for dimension ordering"; 177 if (getDimLevelType().size() != size) 178 return emitError() << "expected an array of size " << size 179 << " for dimension level types"; 180 return success(); 181 } 182 183 SparseTensorEncodingAttr 184 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 185 if (auto ttp = type.dyn_cast<RankedTensorType>()) 186 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 187 return nullptr; 188 } 189 190 //===----------------------------------------------------------------------===// 191 // TensorDialect Operations. 192 //===----------------------------------------------------------------------===// 193 194 static LogicalResult isInBounds(Value dim, Value tensor) { 195 IntegerAttr constantAttr; 196 if (matchPattern(dim, m_Constant(&constantAttr))) { 197 unsigned d = constantAttr.getInt(); 198 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 199 return failure(); 200 } 201 return success(); // in bounds, or symbolic 202 } 203 204 static LogicalResult isMatchingWidth(Value result, unsigned width) { 205 Type etp = result.getType().cast<MemRefType>().getElementType(); 206 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 207 return success(); 208 return failure(); 209 } 210 211 LogicalResult ConvertOp::verify() { 212 if (auto tp1 = getSource().getType().dyn_cast<RankedTensorType>()) { 213 if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) { 214 if (tp1.getRank() != tp2.getRank()) 215 return emitError("unexpected conversion mismatch in rank"); 216 auto shape1 = tp1.getShape(); 217 auto shape2 = tp2.getShape(); 218 // Accept size matches between the source and the destination type 219 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 220 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 221 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) 222 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) 223 return emitError("unexpected conversion mismatch in dimension ") << d; 224 return success(); 225 } 226 } 227 return emitError("unexpected type in convert"); 228 } 229 230 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 231 if (getType() == getSource().getType()) 232 return getSource(); 233 return {}; 234 } 235 236 LogicalResult ToPointersOp::verify() { 237 auto e = getSparseTensorEncoding(getTensor().getType()); 238 if (failed(isInBounds(getDim(), getTensor()))) 239 return emitError("requested pointers dimension out of bounds"); 240 if (failed(isMatchingWidth(getResult(), e.getPointerBitWidth()))) 241 return emitError("unexpected type for pointers"); 242 return success(); 243 } 244 245 LogicalResult ToIndicesOp::verify() { 246 auto e = getSparseTensorEncoding(getTensor().getType()); 247 if (failed(isInBounds(getDim(), getTensor()))) 248 return emitError("requested indices dimension out of bounds"); 249 if (failed(isMatchingWidth(getResult(), e.getIndexBitWidth()))) 250 return emitError("unexpected type for indices"); 251 return success(); 252 } 253 254 LogicalResult ToValuesOp::verify() { 255 RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>(); 256 MemRefType mtp = getResult().getType().cast<MemRefType>(); 257 if (ttp.getElementType() != mtp.getElementType()) 258 return emitError("unexpected mismatch in element types"); 259 return success(); 260 } 261 262 //===----------------------------------------------------------------------===// 263 // TensorDialect Linalg.Generic Operations. 264 //===----------------------------------------------------------------------===// 265 266 template <class T> 267 static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, 268 const char *regionName, 269 TypeRange inputTypes, Type outputType) { 270 unsigned numArgs = region.getNumArguments(); 271 unsigned expectedNum = inputTypes.size(); 272 if (numArgs != expectedNum) 273 return op->emitError() << regionName << " region must have exactly " 274 << expectedNum << " arguments"; 275 276 for (unsigned i = 0; i < numArgs; i++) { 277 Type typ = region.getArgument(i).getType(); 278 if (typ != inputTypes[i]) 279 return op->emitError() << regionName << " region argument " << (i + 1) 280 << " type mismatch"; 281 } 282 Operation *term = region.front().getTerminator(); 283 YieldOp yield = dyn_cast<YieldOp>(term); 284 if (!yield) 285 return op->emitError() << regionName 286 << " region must end with sparse_tensor.yield"; 287 if (yield.getOperand().getType() != outputType) 288 return op->emitError() << regionName << " region yield type mismatch"; 289 290 return success(); 291 } 292 293 LogicalResult BinaryOp::verify() { 294 NamedAttrList attrs = (*this)->getAttrs(); 295 Type leftType = getX().getType(); 296 Type rightType = getY().getType(); 297 Type outputType = getOutput().getType(); 298 Region &overlap = getOverlapRegion(); 299 Region &left = getLeftRegion(); 300 Region &right = getRightRegion(); 301 302 // Check correct number of block arguments and return type for each 303 // non-empty region. 304 LogicalResult regionResult = success(); 305 if (!overlap.empty()) { 306 regionResult = verifyNumBlockArgs( 307 this, overlap, "overlap", TypeRange{leftType, rightType}, outputType); 308 if (failed(regionResult)) 309 return regionResult; 310 } 311 if (!left.empty()) { 312 regionResult = 313 verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType); 314 if (failed(regionResult)) 315 return regionResult; 316 } else if (getLeftIdentity()) { 317 if (leftType != outputType) 318 return emitError("left=identity requires first argument to have the same " 319 "type as the output"); 320 } 321 if (!right.empty()) { 322 regionResult = verifyNumBlockArgs(this, right, "right", 323 TypeRange{rightType}, outputType); 324 if (failed(regionResult)) 325 return regionResult; 326 } else if (getRightIdentity()) { 327 if (rightType != outputType) 328 return emitError("right=identity requires second argument to have the " 329 "same type as the output"); 330 } 331 332 return success(); 333 } 334 335 LogicalResult UnaryOp::verify() { 336 Type inputType = getX().getType(); 337 Type outputType = getOutput().getType(); 338 LogicalResult regionResult = success(); 339 340 // Check correct number of block arguments and return type for each 341 // non-empty region. 342 Region &present = getPresentRegion(); 343 if (!present.empty()) { 344 regionResult = verifyNumBlockArgs(this, present, "present", 345 TypeRange{inputType}, outputType); 346 if (failed(regionResult)) 347 return regionResult; 348 } 349 Region &absent = getAbsentRegion(); 350 if (!absent.empty()) { 351 regionResult = 352 verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType); 353 if (failed(regionResult)) 354 return regionResult; 355 } 356 357 return success(); 358 } 359 360 LogicalResult ReduceOp::verify() { 361 Type inputType = getX().getType(); 362 LogicalResult regionResult = success(); 363 364 // Check correct number of block arguments and return type. 365 Region &formula = getRegion(); 366 if (!formula.empty()) { 367 regionResult = verifyNumBlockArgs( 368 this, formula, "reduce", TypeRange{inputType, inputType}, inputType); 369 if (failed(regionResult)) 370 return regionResult; 371 } 372 373 return success(); 374 } 375 376 LogicalResult YieldOp::verify() { 377 // Check for compatible parent. 378 auto *parentOp = (*this)->getParentOp(); 379 if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) || 380 isa<ReduceOp>(parentOp)) 381 return success(); 382 383 return emitOpError( 384 "expected parent op to be sparse_tensor unary, binary, or reduce"); 385 } 386 387 //===----------------------------------------------------------------------===// 388 // TensorDialect Methods. 389 //===----------------------------------------------------------------------===// 390 391 void SparseTensorDialect::initialize() { 392 addAttributes< 393 #define GET_ATTRDEF_LIST 394 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 395 >(); 396 addOperations< 397 #define GET_OP_LIST 398 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 399 >(); 400 } 401 402 #define GET_OP_CLASSES 403 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 404 405 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 406