//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"

// Returns true iff the given bitwidth is one of the accepted widths
// (0 stands for the native `index` width, cf. isMatchingWidth below).
static bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

// The parser below accepts a dictionary with the optional keys
// "dimLevelType", "dimOrdering", "pointerBitWidth", and "indexBitWidth",
// for example:
//
//   <{ dimLevelType = [ "dense", "compressed" ],
//      dimOrdering = affine_map<(d0, d1) -> (d0, d1)>,
//      pointerBitWidth = 64, indexBitWidth = 64 }>
//
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};
  unsigned ptr = 0;
  unsigned ind = 0;
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      ind = intAttr.getInt();
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  // Print the struct-like storage in dictionary fashion.
  printer << "<{ dimLevelType = [ ";
  for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
    switch (getDimLevelType()[i]) {
    case DimLevelType::Dense:
      printer << "\"dense\"";
      break;
    case DimLevelType::Compressed:
      printer << "\"compressed\"";
      break;
    case DimLevelType::Singleton:
      printer << "\"singleton\"";
      break;
    }
    if (i != e - 1)
      printer << ", ";
  }
  printer << " ]";
  if (getDimOrdering())
    printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
  printer << ", pointerBitWidth = " << getPointerBitWidth()
          << ", indexBitWidth = " << getIndexBitWidth() << " }>";
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
    unsigned pointerBitWidth, unsigned indexBitWidth) {
  if (!acceptBitWidth(pointerBitWidth))
    return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
  if (!acceptBitWidth(indexBitWidth))
    return emitError() << "unexpected index bitwidth: " << indexBitWidth;
  if (dimOrdering) {
    if (!dimOrdering.isPermutation())
      return emitError()
             << "expected a permutation affine map for dimension ordering";
    if (dimOrdering.getNumResults() != dimLevelType.size())
      return emitError() << "unexpected mismatch in ordering and dimension "
                            "level types size";
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<int64_t> shape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.
  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
                    getPointerBitWidth(), getIndexBitWidth())))
    return failure();
  // Check integrity with tensor type specifics. The dimension ordering is
  // optional, but we should always have dimension level types for the full
  // rank.
  unsigned size = shape.size();
  if (size == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimOrdering() && getDimOrdering().getNumResults() != size)
    return emitError() << "expected an affine map of size " << size
                       << " for dimension ordering";
  if (getDimLevelType().size() != size)
    return emitError() << "expected an array of size " << size
                       << " for dimension level types";
  return success();
}

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = type.dyn_cast<RankedTensorType>())
    return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

static LogicalResult isInBounds(Value dim, Value tensor) {
  if (auto constantOp = dim.getDefiningOp<arith::ConstantOp>()) {
    unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt();
    if (d >= tensor.getType().cast<RankedTensorType>().getRank())
      return failure();
  }
  return success(); // in bounds, or symbolic
}

static LogicalResult isMatchingWidth(Value result, unsigned width) {
  Type etp = result.getType().cast<MemRefType>().getElementType();
  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
    return success();
  return failure();
}

static LogicalResult verify(NewOp op) {
  if (!getSparseTensorEncoding(op.result().getType()))
    return op.emitError("expected a sparse tensor result");
  return success();
}

static LogicalResult verify(InitOp op) {
  if (!getSparseTensorEncoding(op.result().getType()))
    return op.emitError("expected a sparse tensor result");
  RankedTensorType ttp = op.getType().cast<RankedTensorType>();
  unsigned rank = ttp.getRank();
  if (rank != op.sizes().size())
    return op.emitError("unexpected mismatch between tensor rank and sizes: ")
           << rank << " vs. " << op.sizes().size();
  auto shape = ttp.getShape();
  for (unsigned i = 0; i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      continue;
    auto constantOp = op.sizes()[i].getDefiningOp<arith::ConstantOp>();
    if (!constantOp ||
        constantOp.getValue().cast<IntegerAttr>().getInt() != shape[i])
      return op.emitError("unexpected mismatch with static dimension size ")
             << shape[i];
  }
  return success();
}

static LogicalResult verify(ConvertOp op) {
  if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
    if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
      if (tp1.getRank() != tp2.getRank())
        return op.emitError("unexpected conversion mismatch in rank");
      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
          return op.emitError("unexpected conversion mismatch in dimension ")
                 << d;
      }
      return success();
    }
  }
  return op.emitError("unexpected type in convert");
}

OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  if (getType() == source().getType())
    return source();
  return {};
}

static LogicalResult verify(ToPointersOp op) {
  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
    if (failed(isInBounds(op.dim(), op.tensor())))
      return op.emitError("requested pointers dimension out of bounds");
    if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
      return op.emitError("unexpected type for pointers");
    return success();
  }
  return op.emitError("expected a sparse tensor to get pointers");
}

static LogicalResult verify(ToIndicesOp op) {
  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
    if (failed(isInBounds(op.dim(), op.tensor())))
      return op.emitError("requested indices dimension out of bounds");
    if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
      return op.emitError("unexpected type for indices");
    return success();
  }
  return op.emitError("expected a sparse tensor to get indices");
}

static LogicalResult verify(ToValuesOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to get values");
  RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
  MemRefType mtp = op.result().getType().cast<MemRefType>();
  if (ttp.getElementType() != mtp.getElementType())
    return op.emitError("unexpected mismatch in element types");
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Management Operations.
//===----------------------------------------------------------------------===//

static LogicalResult verify(LexInsertOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for insertion");
  return success();
}

static LogicalResult verify(ExpandOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for expansion");
  return success();
}

static LogicalResult verify(CompressOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for compression");
  return success();
}

static LogicalResult verify(LoadOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to materialize");
  return success();
}

static LogicalResult verify(ReleaseOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to release");
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Methods.
//===----------------------------------------------------------------------===//

void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser,
                                              Type type) const {
  StringRef attrTag;
  if (failed(parser.parseKeyword(&attrTag)))
    return Attribute();
  Attribute attr;
  auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
  if (parseResult.hasValue())
    return attr;
  parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute");
  return Attribute();
}

void SparseTensorDialect::printAttribute(Attribute attr,
                                         DialectAsmPrinter &printer) const {
  if (succeeded(generatedAttributePrinter(attr, printer)))
    return;
}