1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 10 #include "mlir/IR/Builders.h" 11 #include "mlir/IR/DialectImplementation.h" 12 #include "mlir/IR/Matchers.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "llvm/ADT/TypeSwitch.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 20 21 //===----------------------------------------------------------------------===// 22 // TensorDialect Attribute Methods. 23 //===----------------------------------------------------------------------===// 24 25 #define GET_ATTRDEF_CLASSES 26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 27 28 static bool acceptBitWidth(unsigned bitWidth) { 29 switch (bitWidth) { 30 case 0: 31 case 8: 32 case 16: 33 case 32: 34 case 64: 35 return true; 36 default: 37 return false; 38 } 39 } 40 41 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 42 if (failed(parser.parseLess())) 43 return {}; 44 // Parse the data as a dictionary. 45 DictionaryAttr dict; 46 if (failed(parser.parseAttribute(dict))) 47 return {}; 48 if (failed(parser.parseGreater())) 49 return {}; 50 // Process the data from the parsed dictionary value into struct-like data. 51 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 52 AffineMap map = {}; 53 unsigned ptr = 0; 54 unsigned ind = 0; 55 for (const NamedAttribute &attr : dict) { 56 if (attr.getName() == "dimLevelType") { 57 auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); 58 if (!arrayAttr) { 59 parser.emitError(parser.getNameLoc(), 60 "expected an array for dimension level types"); 61 return {}; 62 } 63 for (auto i : arrayAttr) { 64 auto strAttr = i.dyn_cast<StringAttr>(); 65 if (!strAttr) { 66 parser.emitError(parser.getNameLoc(), 67 "expected a string value in dimension level types"); 68 return {}; 69 } 70 auto strVal = strAttr.getValue(); 71 if (strVal == "dense") { 72 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 73 } else if (strVal == "compressed") { 74 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 75 } else if (strVal == "singleton") { 76 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 77 } else { 78 parser.emitError(parser.getNameLoc(), 79 "unexpected dimension level type: ") 80 << strVal; 81 return {}; 82 } 83 } 84 } else if (attr.getName() == "dimOrdering") { 85 auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>(); 86 if (!affineAttr) { 87 parser.emitError(parser.getNameLoc(), 88 "expected an affine map for dimension ordering"); 89 return {}; 90 } 91 map = affineAttr.getValue(); 92 } else if (attr.getName() == "pointerBitWidth") { 93 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 94 if (!intAttr) { 95 parser.emitError(parser.getNameLoc(), 96 "expected an integral pointer bitwidth"); 97 return {}; 98 } 99 ptr = intAttr.getInt(); 100 } else if (attr.getName() == "indexBitWidth") { 101 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 102 if (!intAttr) { 103 parser.emitError(parser.getNameLoc(), 104 "expected an integral index bitwidth"); 105 return {}; 106 } 107 ind = intAttr.getInt(); 108 } else { 109 parser.emitError(parser.getNameLoc(), "unexpected key: ") 110 << attr.getName().strref(); 111 return {}; 112 } 113 } 114 // Construct struct-like storage for attribute. 115 return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, 116 map, ptr, ind); 117 } 118 119 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 120 // Print the struct-like storage in dictionary fashion. 121 printer << "<{ dimLevelType = [ "; 122 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 123 switch (getDimLevelType()[i]) { 124 case DimLevelType::Dense: 125 printer << "\"dense\""; 126 break; 127 case DimLevelType::Compressed: 128 printer << "\"compressed\""; 129 break; 130 case DimLevelType::Singleton: 131 printer << "\"singleton\""; 132 break; 133 } 134 if (i != e - 1) 135 printer << ", "; 136 } 137 printer << " ]"; 138 if (getDimOrdering()) 139 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 140 printer << ", pointerBitWidth = " << getPointerBitWidth() 141 << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 142 } 143 144 LogicalResult SparseTensorEncodingAttr::verify( 145 function_ref<InFlightDiagnostic()> emitError, 146 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 147 unsigned pointerBitWidth, unsigned indexBitWidth) { 148 if (!acceptBitWidth(pointerBitWidth)) 149 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 150 if (!acceptBitWidth(indexBitWidth)) 151 return emitError() << "unexpected index bitwidth: " << indexBitWidth; 152 if (dimOrdering) { 153 if (!dimOrdering.isPermutation()) 154 return emitError() 155 << "expected a permutation affine map for dimension ordering"; 156 if (dimOrdering.getNumResults() != dimLevelType.size()) 157 return emitError() << "unexpected mismatch in ordering and dimension " 158 "level types size"; 159 } 160 return success(); 161 } 162 163 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 164 ArrayRef<int64_t> shape, Type elementType, 165 function_ref<InFlightDiagnostic()> emitError) const { 166 // Check structural integrity. 167 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 168 getPointerBitWidth(), getIndexBitWidth()))) 169 return failure(); 170 // Check integrity with tensor type specifics. Dimension ordering is optional, 171 // but we always should have dimension level types for the full rank. 172 unsigned size = shape.size(); 173 if (size == 0) 174 return emitError() << "expected non-scalar sparse tensor"; 175 if (getDimOrdering() && getDimOrdering().getNumResults() != size) 176 return emitError() << "expected an affine map of size " << size 177 << " for dimension ordering"; 178 if (getDimLevelType().size() != size) 179 return emitError() << "expected an array of size " << size 180 << " for dimension level types"; 181 return success(); 182 } 183 184 SparseTensorEncodingAttr 185 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 186 if (auto ttp = type.dyn_cast<RankedTensorType>()) 187 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 188 return nullptr; 189 } 190 191 //===----------------------------------------------------------------------===// 192 // TensorDialect Operations. 193 //===----------------------------------------------------------------------===// 194 195 static LogicalResult isInBounds(Value dim, Value tensor) { 196 IntegerAttr constantAttr; 197 if (matchPattern(dim, m_Constant(&constantAttr))) { 198 unsigned d = constantAttr.getInt(); 199 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 200 return failure(); 201 } 202 return success(); // in bounds, or symbolic 203 } 204 205 static LogicalResult isMatchingWidth(Value result, unsigned width) { 206 Type etp = result.getType().cast<MemRefType>().getElementType(); 207 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 208 return success(); 209 return failure(); 210 } 211 212 static LogicalResult verify(NewOp op) { 213 if (!getSparseTensorEncoding(op.result().getType())) 214 return op.emitError("expected a sparse tensor result"); 215 return success(); 216 } 217 218 static LogicalResult verify(InitOp op) { 219 if (!getSparseTensorEncoding(op.result().getType())) 220 return op.emitError("expected a sparse tensor result"); 221 RankedTensorType ttp = op.getType().cast<RankedTensorType>(); 222 unsigned rank = ttp.getRank(); 223 if (rank != op.sizes().size()) 224 return op.emitError("unexpected mismatch between tensor rank and sizes: ") 225 << rank << " vs. " << op.sizes().size(); 226 auto shape = ttp.getShape(); 227 for (unsigned i = 0; i < rank; i++) { 228 if (shape[i] == ShapedType::kDynamicSize) 229 continue; 230 IntegerAttr constantAttr; 231 if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) || 232 constantAttr.getInt() != shape[i]) { 233 return op.emitError("unexpected mismatch with static dimension size ") 234 << shape[i]; 235 } 236 } 237 return success(); 238 } 239 240 static LogicalResult verify(ConvertOp op) { 241 if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) { 242 if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) { 243 if (tp1.getRank() != tp2.getRank()) 244 return op.emitError("unexpected conversion mismatch in rank"); 245 auto shape1 = tp1.getShape(); 246 auto shape2 = tp2.getShape(); 247 // Accept size matches between the source and the destination type 248 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 249 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 250 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) { 251 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) 252 return op.emitError("unexpected conversion mismatch in dimension ") 253 << d; 254 } 255 return success(); 256 } 257 } 258 return op.emitError("unexpected type in convert"); 259 } 260 261 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 262 if (getType() == source().getType()) 263 return source(); 264 return {}; 265 } 266 267 static LogicalResult verify(ToPointersOp op) { 268 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 269 if (failed(isInBounds(op.dim(), op.tensor()))) 270 return op.emitError("requested pointers dimension out of bounds"); 271 if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) 272 return op.emitError("unexpected type for pointers"); 273 return success(); 274 } 275 return op.emitError("expected a sparse tensor to get pointers"); 276 } 277 278 static LogicalResult verify(ToIndicesOp op) { 279 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 280 if (failed(isInBounds(op.dim(), op.tensor()))) 281 return op.emitError("requested indices dimension out of bounds"); 282 if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) 283 return op.emitError("unexpected type for indices"); 284 return success(); 285 } 286 return op.emitError("expected a sparse tensor to get indices"); 287 } 288 289 static LogicalResult verify(ToValuesOp op) { 290 if (!getSparseTensorEncoding(op.tensor().getType())) 291 return op.emitError("expected a sparse tensor to get values"); 292 RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>(); 293 MemRefType mtp = op.result().getType().cast<MemRefType>(); 294 if (ttp.getElementType() != mtp.getElementType()) 295 return op.emitError("unexpected mismatch in element types"); 296 return success(); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // TensorDialect Management Operations. 301 //===----------------------------------------------------------------------===// 302 303 static LogicalResult verify(LexInsertOp op) { 304 if (!getSparseTensorEncoding(op.tensor().getType())) 305 return op.emitError("expected a sparse tensor for insertion"); 306 return success(); 307 } 308 309 static LogicalResult verify(ExpandOp op) { 310 if (!getSparseTensorEncoding(op.tensor().getType())) 311 return op.emitError("expected a sparse tensor for expansion"); 312 return success(); 313 } 314 315 static LogicalResult verify(CompressOp op) { 316 if (!getSparseTensorEncoding(op.tensor().getType())) 317 return op.emitError("expected a sparse tensor for compression"); 318 return success(); 319 } 320 321 static LogicalResult verify(LoadOp op) { 322 if (!getSparseTensorEncoding(op.tensor().getType())) 323 return op.emitError("expected a sparse tensor to materialize"); 324 return success(); 325 } 326 327 static LogicalResult verify(ReleaseOp op) { 328 if (!getSparseTensorEncoding(op.tensor().getType())) 329 return op.emitError("expected a sparse tensor to release"); 330 return success(); 331 } 332 333 static LogicalResult verify(OutOp op) { 334 if (!getSparseTensorEncoding(op.tensor().getType())) 335 return op.emitError("expected a sparse tensor for output"); 336 return success(); 337 } 338 339 //===----------------------------------------------------------------------===// 340 // TensorDialect Methods. 341 //===----------------------------------------------------------------------===// 342 343 void SparseTensorDialect::initialize() { 344 addAttributes< 345 #define GET_ATTRDEF_LIST 346 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 347 >(); 348 addOperations< 349 #define GET_OP_LIST 350 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 351 >(); 352 } 353 354 #define GET_OP_CLASSES 355 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 356