1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 10 #include "mlir/Dialect/StandardOps/IR/Ops.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/DialectImplementation.h" 13 #include "mlir/IR/OpImplementation.h" 14 #include "llvm/ADT/TypeSwitch.h" 15 16 using namespace mlir; 17 using namespace mlir::sparse_tensor; 18 19 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 20 21 //===----------------------------------------------------------------------===// 22 // TensorDialect Attribute Methods. 23 //===----------------------------------------------------------------------===// 24 25 #define GET_ATTRDEF_CLASSES 26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 27 28 static bool acceptBitWidth(unsigned bitWidth) { 29 switch (bitWidth) { 30 case 0: 31 case 8: 32 case 16: 33 case 32: 34 case 64: 35 return true; 36 default: 37 return false; 38 } 39 } 40 41 Attribute SparseTensorEncodingAttr::parse(DialectAsmParser &parser, Type type) { 42 if (failed(parser.parseLess())) 43 return {}; 44 // Parse the data as a dictionary. 45 DictionaryAttr dict; 46 if (failed(parser.parseAttribute(dict))) 47 return {}; 48 if (failed(parser.parseGreater())) 49 return {}; 50 // Process the data from the parsed dictionary value into struct-like data. 51 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 52 AffineMap map = {}; 53 unsigned ptr = 0; 54 unsigned ind = 0; 55 for (const NamedAttribute &attr : dict) { 56 if (attr.first == "dimLevelType") { 57 auto arrayAttr = attr.second.dyn_cast<ArrayAttr>(); 58 if (!arrayAttr) { 59 parser.emitError(parser.getNameLoc(), 60 "expected an array for dimension level types"); 61 return {}; 62 } 63 for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) { 64 auto strAttr = arrayAttr[i].dyn_cast<StringAttr>(); 65 if (!strAttr) { 66 parser.emitError(parser.getNameLoc(), 67 "expected a string value in dimension level types"); 68 return {}; 69 } 70 auto strVal = strAttr.getValue(); 71 if (strVal == "dense") { 72 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 73 } else if (strVal == "compressed") { 74 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 75 } else if (strVal == "singleton") { 76 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 77 } else { 78 parser.emitError(parser.getNameLoc(), 79 "unexpected dimension level type: ") 80 << strVal; 81 return {}; 82 } 83 } 84 } else if (attr.first == "dimOrdering") { 85 auto affineAttr = attr.second.dyn_cast<AffineMapAttr>(); 86 if (!affineAttr) { 87 parser.emitError(parser.getNameLoc(), 88 "expected an affine map for dimension ordering"); 89 return {}; 90 } 91 map = affineAttr.getValue(); 92 } else if (attr.first == "pointerBitWidth") { 93 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 94 if (!intAttr) { 95 parser.emitError(parser.getNameLoc(), 96 "expected an integral pointer bitwidth"); 97 return {}; 98 } 99 ptr = intAttr.getInt(); 100 } else if (attr.first == "indexBitWidth") { 101 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 102 if (!intAttr) { 103 parser.emitError(parser.getNameLoc(), 104 "expected an integral index bitwidth"); 105 return {}; 106 } 107 ind = intAttr.getInt(); 108 } else { 109 parser.emitError(parser.getNameLoc(), "unexpected key: ") 110 << attr.first.str(); 111 return {}; 112 } 113 } 114 // Construct struct-like storage for attribute. 115 return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, 116 map, ptr, ind); 117 } 118 119 void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const { 120 // Print the struct-like storage in dictionary fashion. 121 printer << "encoding<{ dimLevelType = [ "; 122 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 123 switch (getDimLevelType()[i]) { 124 case DimLevelType::Dense: 125 printer << "\"dense\""; 126 break; 127 case DimLevelType::Compressed: 128 printer << "\"compressed\""; 129 break; 130 case DimLevelType::Singleton: 131 printer << "\"singleton\""; 132 break; 133 } 134 if (i != e - 1) 135 printer << ", "; 136 } 137 printer << " ]"; 138 if (getDimOrdering()) 139 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 140 printer << ", pointerBitWidth = " << getPointerBitWidth() 141 << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 142 } 143 144 LogicalResult SparseTensorEncodingAttr::verify( 145 function_ref<InFlightDiagnostic()> emitError, 146 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 147 unsigned pointerBitWidth, unsigned indexBitWidth) { 148 if (!acceptBitWidth(pointerBitWidth)) 149 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 150 if (!acceptBitWidth(indexBitWidth)) 151 return emitError() << "unexpected index bitwidth: " << indexBitWidth; 152 if (dimOrdering) { 153 if (!dimOrdering.isPermutation()) 154 return emitError() 155 << "expected a permutation affine map for dimension ordering"; 156 if (dimOrdering.getNumResults() != dimLevelType.size()) 157 return emitError() << "unexpected mismatch in ordering and dimension " 158 "level types size"; 159 } 160 return success(); 161 } 162 163 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 164 ArrayRef<int64_t> shape, Type elementType, 165 function_ref<InFlightDiagnostic()> emitError) const { 166 // Check structural integrity. 167 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 168 getPointerBitWidth(), getIndexBitWidth()))) 169 return failure(); 170 // Check integrity with tensor type specifics. Dimension ordering is optional, 171 // but we always should have dimension level types for the full rank. 172 unsigned size = shape.size(); 173 if (getDimOrdering() && getDimOrdering().getNumResults() != size) 174 return emitError() << "expected an affine map of size " << size 175 << " for dimension ordering"; 176 if (getDimLevelType().size() != size) 177 return emitError() << "expected an array of size " << size 178 << " for dimension level types"; 179 return success(); 180 } 181 182 SparseTensorEncodingAttr 183 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 184 if (auto ttp = type.dyn_cast<RankedTensorType>()) 185 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 186 return nullptr; 187 } 188 189 //===----------------------------------------------------------------------===// 190 // TensorDialect Operations. 191 //===----------------------------------------------------------------------===// 192 193 static LogicalResult isInBounds(Value dim, Value tensor) { 194 if (auto constantOp = dim.getDefiningOp<ConstantOp>()) { 195 unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt(); 196 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 197 return failure(); 198 } 199 return success(); // in bounds, or symbolic 200 } 201 202 static LogicalResult isMatchingWidth(Value result, unsigned width) { 203 Type etp = result.getType().cast<MemRefType>().getElementType(); 204 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 205 return success(); 206 return failure(); 207 } 208 209 static LogicalResult verify(NewOp op) { 210 if (!getSparseTensorEncoding(op.result().getType())) 211 return op.emitError("expected a sparse tensor result"); 212 return success(); 213 } 214 215 static LogicalResult verify(ConvertOp op) { 216 if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) { 217 if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) { 218 assert(tp1.getRank() == tp2.getRank()); 219 auto shape1 = tp1.getShape(); 220 auto shape2 = tp2.getShape(); 221 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) { 222 if (shape1[d] != shape2[d]) 223 return op.emitError() 224 << "unexpected conversion mismatch in dimension " << d; 225 } 226 return success(); 227 } 228 } 229 return op.emitError("unexpected type in convert"); 230 } 231 232 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 233 if (getType() == source().getType()) 234 return source(); 235 return {}; 236 } 237 238 static LogicalResult verify(ToPointersOp op) { 239 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 240 if (failed(isInBounds(op.dim(), op.tensor()))) 241 return op.emitError("requested pointers dimension out of bounds"); 242 if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) 243 return op.emitError("unexpected type for pointers"); 244 return success(); 245 } 246 return op.emitError("expected a sparse tensor to get pointers"); 247 } 248 249 static LogicalResult verify(ToIndicesOp op) { 250 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 251 if (failed(isInBounds(op.dim(), op.tensor()))) 252 return op.emitError("requested indices dimension out of bounds"); 253 if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) 254 return op.emitError("unexpected type for indices"); 255 return success(); 256 } 257 return op.emitError("expected a sparse tensor to get indices"); 258 } 259 260 static LogicalResult verify(ToValuesOp op) { 261 if (!getSparseTensorEncoding(op.tensor().getType())) 262 return op.emitError("expected a sparse tensor to get values"); 263 RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>(); 264 MemRefType mtp = op.result().getType().cast<MemRefType>(); 265 if (ttp.getElementType() != mtp.getElementType()) 266 return op.emitError("unexpected mismatch in element types"); 267 return success(); 268 } 269 270 static LogicalResult verify(ToTensorOp op) { 271 if (!getSparseTensorEncoding(op.result().getType())) 272 return op.emitError("expected a sparse tensor as result"); 273 return success(); 274 } 275 276 //===----------------------------------------------------------------------===// 277 // TensorDialect Methods. 278 //===----------------------------------------------------------------------===// 279 280 void SparseTensorDialect::initialize() { 281 addAttributes< 282 #define GET_ATTRDEF_LIST 283 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 284 >(); 285 addOperations< 286 #define GET_OP_LIST 287 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 288 >(); 289 } 290 291 #define GET_OP_CLASSES 292 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 293 294 Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser, 295 Type type) const { 296 StringRef attrTag; 297 if (failed(parser.parseKeyword(&attrTag))) 298 return Attribute(); 299 Attribute attr; 300 auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); 301 if (parseResult.hasValue()) 302 return attr; 303 parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute"); 304 return Attribute(); 305 } 306 307 void SparseTensorDialect::printAttribute(Attribute attr, 308 DialectAsmPrinter &printer) const { 309 if (succeeded(generatedAttributePrinter(attr, printer))) 310 return; 311 } 312