1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 10 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/Matchers.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 17 using namespace mlir; 18 using namespace mlir::sparse_tensor; 19 20 //===----------------------------------------------------------------------===// 21 // TensorDialect Attribute Methods. 22 //===----------------------------------------------------------------------===// 23 24 #define GET_ATTRDEF_CLASSES 25 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 26 27 static bool acceptBitWidth(unsigned bitWidth) { 28 switch (bitWidth) { 29 case 0: 30 case 8: 31 case 16: 32 case 32: 33 case 64: 34 return true; 35 default: 36 return false; 37 } 38 } 39 40 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 41 if (failed(parser.parseLess())) 42 return {}; 43 // Parse the data as a dictionary. 44 DictionaryAttr dict; 45 if (failed(parser.parseAttribute(dict))) 46 return {}; 47 if (failed(parser.parseGreater())) 48 return {}; 49 // Process the data from the parsed dictionary value into struct-like data. 50 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 51 AffineMap map = {}; 52 unsigned ptr = 0; 53 unsigned ind = 0; 54 for (const NamedAttribute &attr : dict) { 55 if (attr.getName() == "dimLevelType") { 56 auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); 57 if (!arrayAttr) { 58 parser.emitError(parser.getNameLoc(), 59 "expected an array for dimension level types"); 60 return {}; 61 } 62 for (auto i : arrayAttr) { 63 auto strAttr = i.dyn_cast<StringAttr>(); 64 if (!strAttr) { 65 parser.emitError(parser.getNameLoc(), 66 "expected a string value in dimension level types"); 67 return {}; 68 } 69 auto strVal = strAttr.getValue(); 70 if (strVal == "dense") { 71 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 72 } else if (strVal == "compressed") { 73 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 74 } else if (strVal == "singleton") { 75 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 76 } else { 77 parser.emitError(parser.getNameLoc(), 78 "unexpected dimension level type: ") 79 << strVal; 80 return {}; 81 } 82 } 83 } else if (attr.getName() == "dimOrdering") { 84 auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>(); 85 if (!affineAttr) { 86 parser.emitError(parser.getNameLoc(), 87 "expected an affine map for dimension ordering"); 88 return {}; 89 } 90 map = affineAttr.getValue(); 91 } else if (attr.getName() == "pointerBitWidth") { 92 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 93 if (!intAttr) { 94 parser.emitError(parser.getNameLoc(), 95 "expected an integral pointer bitwidth"); 96 return {}; 97 } 98 ptr = intAttr.getInt(); 99 } else if (attr.getName() == "indexBitWidth") { 100 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 101 if (!intAttr) { 102 parser.emitError(parser.getNameLoc(), 103 "expected an integral index bitwidth"); 104 return {}; 105 } 106 ind = intAttr.getInt(); 107 } else { 108 parser.emitError(parser.getNameLoc(), "unexpected key: ") 109 << attr.getName().strref(); 110 return {}; 111 } 112 } 113 // Construct struct-like storage for attribute. 114 return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, 115 map, ptr, ind); 116 } 117 118 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 119 // Print the struct-like storage in dictionary fashion. 120 printer << "<{ dimLevelType = [ "; 121 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 122 switch (getDimLevelType()[i]) { 123 case DimLevelType::Dense: 124 printer << "\"dense\""; 125 break; 126 case DimLevelType::Compressed: 127 printer << "\"compressed\""; 128 break; 129 case DimLevelType::Singleton: 130 printer << "\"singleton\""; 131 break; 132 } 133 if (i != e - 1) 134 printer << ", "; 135 } 136 printer << " ]"; 137 if (getDimOrdering()) 138 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 139 printer << ", pointerBitWidth = " << getPointerBitWidth() 140 << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 141 } 142 143 LogicalResult SparseTensorEncodingAttr::verify( 144 function_ref<InFlightDiagnostic()> emitError, 145 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 146 unsigned pointerBitWidth, unsigned indexBitWidth) { 147 if (!acceptBitWidth(pointerBitWidth)) 148 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 149 if (!acceptBitWidth(indexBitWidth)) 150 return emitError() << "unexpected index bitwidth: " << indexBitWidth; 151 if (dimOrdering) { 152 if (!dimOrdering.isPermutation()) 153 return emitError() 154 << "expected a permutation affine map for dimension ordering"; 155 if (dimOrdering.getNumResults() != dimLevelType.size()) 156 return emitError() << "unexpected mismatch in ordering and dimension " 157 "level types size"; 158 } 159 return success(); 160 } 161 162 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 163 ArrayRef<int64_t> shape, Type elementType, 164 function_ref<InFlightDiagnostic()> emitError) const { 165 // Check structural integrity. 166 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 167 getPointerBitWidth(), getIndexBitWidth()))) 168 return failure(); 169 // Check integrity with tensor type specifics. Dimension ordering is optional, 170 // but we always should have dimension level types for the full rank. 171 unsigned size = shape.size(); 172 if (size == 0) 173 return emitError() << "expected non-scalar sparse tensor"; 174 if (getDimOrdering() && getDimOrdering().getNumResults() != size) 175 return emitError() << "expected an affine map of size " << size 176 << " for dimension ordering"; 177 if (getDimLevelType().size() != size) 178 return emitError() << "expected an array of size " << size 179 << " for dimension level types"; 180 return success(); 181 } 182 183 SparseTensorEncodingAttr 184 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 185 if (auto ttp = type.dyn_cast<RankedTensorType>()) 186 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 187 return nullptr; 188 } 189 190 //===----------------------------------------------------------------------===// 191 // TensorDialect Operations. 192 //===----------------------------------------------------------------------===// 193 194 static LogicalResult isInBounds(Value dim, Value tensor) { 195 IntegerAttr constantAttr; 196 if (matchPattern(dim, m_Constant(&constantAttr))) { 197 unsigned d = constantAttr.getInt(); 198 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 199 return failure(); 200 } 201 return success(); // in bounds, or symbolic 202 } 203 204 static LogicalResult isMatchingWidth(Value result, unsigned width) { 205 Type etp = result.getType().cast<MemRefType>().getElementType(); 206 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 207 return success(); 208 return failure(); 209 } 210 211 LogicalResult NewOp::verify() { 212 if (!getSparseTensorEncoding(result().getType())) 213 return emitError("expected a sparse tensor result"); 214 return success(); 215 } 216 217 LogicalResult InitOp::verify() { 218 if (!getSparseTensorEncoding(result().getType())) 219 return emitError("expected a sparse tensor result"); 220 RankedTensorType ttp = getType().cast<RankedTensorType>(); 221 unsigned rank = ttp.getRank(); 222 if (rank != sizes().size()) 223 return emitError("unexpected mismatch between tensor rank and sizes: ") 224 << rank << " vs. " << sizes().size(); 225 auto shape = ttp.getShape(); 226 for (unsigned i = 0; i < rank; i++) { 227 if (shape[i] == ShapedType::kDynamicSize) 228 continue; 229 IntegerAttr constantAttr; 230 if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) || 231 constantAttr.getInt() != shape[i]) { 232 return emitError("unexpected mismatch with static dimension size ") 233 << shape[i]; 234 } 235 } 236 return success(); 237 } 238 239 LogicalResult ConvertOp::verify() { 240 if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) { 241 if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) { 242 if (tp1.getRank() != tp2.getRank()) 243 return emitError("unexpected conversion mismatch in rank"); 244 auto shape1 = tp1.getShape(); 245 auto shape2 = tp2.getShape(); 246 // Accept size matches between the source and the destination type 247 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 248 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 249 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) 250 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) 251 return emitError("unexpected conversion mismatch in dimension ") << d; 252 return success(); 253 } 254 } 255 return emitError("unexpected type in convert"); 256 } 257 258 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 259 if (getType() == source().getType()) 260 return source(); 261 return {}; 262 } 263 264 LogicalResult ToPointersOp::verify() { 265 if (auto e = getSparseTensorEncoding(tensor().getType())) { 266 if (failed(isInBounds(dim(), tensor()))) 267 return emitError("requested pointers dimension out of bounds"); 268 if (failed(isMatchingWidth(result(), e.getPointerBitWidth()))) 269 return emitError("unexpected type for pointers"); 270 return success(); 271 } 272 return emitError("expected a sparse tensor to get pointers"); 273 } 274 275 LogicalResult ToIndicesOp::verify() { 276 if (auto e = getSparseTensorEncoding(tensor().getType())) { 277 if (failed(isInBounds(dim(), tensor()))) 278 return emitError("requested indices dimension out of bounds"); 279 if (failed(isMatchingWidth(result(), e.getIndexBitWidth()))) 280 return emitError("unexpected type for indices"); 281 return success(); 282 } 283 return emitError("expected a sparse tensor to get indices"); 284 } 285 286 LogicalResult ToValuesOp::verify() { 287 if (!getSparseTensorEncoding(tensor().getType())) 288 return emitError("expected a sparse tensor to get values"); 289 RankedTensorType ttp = tensor().getType().cast<RankedTensorType>(); 290 MemRefType mtp = result().getType().cast<MemRefType>(); 291 if (ttp.getElementType() != mtp.getElementType()) 292 return emitError("unexpected mismatch in element types"); 293 return success(); 294 } 295 296 //===----------------------------------------------------------------------===// 297 // TensorDialect Management Operations. 298 //===----------------------------------------------------------------------===// 299 300 LogicalResult LexInsertOp::verify() { 301 if (!getSparseTensorEncoding(tensor().getType())) 302 return emitError("expected a sparse tensor for insertion"); 303 return success(); 304 } 305 306 LogicalResult ExpandOp::verify() { 307 if (!getSparseTensorEncoding(tensor().getType())) 308 return emitError("expected a sparse tensor for expansion"); 309 return success(); 310 } 311 312 LogicalResult CompressOp::verify() { 313 if (!getSparseTensorEncoding(tensor().getType())) 314 return emitError("expected a sparse tensor for compression"); 315 return success(); 316 } 317 318 LogicalResult LoadOp::verify() { 319 if (!getSparseTensorEncoding(tensor().getType())) 320 return emitError("expected a sparse tensor to materialize"); 321 return success(); 322 } 323 324 LogicalResult ReleaseOp::verify() { 325 if (!getSparseTensorEncoding(tensor().getType())) 326 return emitError("expected a sparse tensor to release"); 327 return success(); 328 } 329 330 LogicalResult OutOp::verify() { 331 if (!getSparseTensorEncoding(tensor().getType())) 332 return emitError("expected a sparse tensor for output"); 333 return success(); 334 } 335 336 //===----------------------------------------------------------------------===// 337 // TensorDialect Methods. 338 //===----------------------------------------------------------------------===// 339 340 void SparseTensorDialect::initialize() { 341 addAttributes< 342 #define GET_ATTRDEF_LIST 343 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 344 >(); 345 addOperations< 346 #define GET_OP_LIST 347 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 348 >(); 349 } 350 351 #define GET_OP_CLASSES 352 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 353 354 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 355