1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 17 using namespace mlir; 18 using namespace mlir::sparse_tensor; 19 20 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 21 22 //===----------------------------------------------------------------------===// 23 // TensorDialect Attribute Methods. 24 //===----------------------------------------------------------------------===// 25 26 #define GET_ATTRDEF_CLASSES 27 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 28 29 static bool acceptBitWidth(unsigned bitWidth) { 30 switch (bitWidth) { 31 case 0: 32 case 8: 33 case 16: 34 case 32: 35 case 64: 36 return true; 37 default: 38 return false; 39 } 40 } 41 42 Attribute SparseTensorEncodingAttr::parse(DialectAsmParser &parser, Type type) { 43 if (failed(parser.parseLess())) 44 return {}; 45 // Parse the data as a dictionary. 46 DictionaryAttr dict; 47 if (failed(parser.parseAttribute(dict))) 48 return {}; 49 if (failed(parser.parseGreater())) 50 return {}; 51 // Process the data from the parsed dictionary value into struct-like data. 52 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 53 AffineMap map = {}; 54 unsigned ptr = 0; 55 unsigned ind = 0; 56 for (const NamedAttribute &attr : dict) { 57 if (attr.first == "dimLevelType") { 58 auto arrayAttr = attr.second.dyn_cast<ArrayAttr>(); 59 if (!arrayAttr) { 60 parser.emitError(parser.getNameLoc(), 61 "expected an array for dimension level types"); 62 return {}; 63 } 64 for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) { 65 auto strAttr = arrayAttr[i].dyn_cast<StringAttr>(); 66 if (!strAttr) { 67 parser.emitError(parser.getNameLoc(), 68 "expected a string value in dimension level types"); 69 return {}; 70 } 71 auto strVal = strAttr.getValue(); 72 if (strVal == "dense") { 73 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 74 } else if (strVal == "compressed") { 75 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 76 } else if (strVal == "singleton") { 77 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 78 } else { 79 parser.emitError(parser.getNameLoc(), 80 "unexpected dimension level type: ") 81 << strVal; 82 return {}; 83 } 84 } 85 } else if (attr.first == "dimOrdering") { 86 auto affineAttr = attr.second.dyn_cast<AffineMapAttr>(); 87 if (!affineAttr) { 88 parser.emitError(parser.getNameLoc(), 89 "expected an affine map for dimension ordering"); 90 return {}; 91 } 92 map = affineAttr.getValue(); 93 } else if (attr.first == "pointerBitWidth") { 94 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 95 if (!intAttr) { 96 parser.emitError(parser.getNameLoc(), 97 "expected an integral pointer bitwidth"); 98 return {}; 99 } 100 ptr = intAttr.getInt(); 101 } else if (attr.first == "indexBitWidth") { 102 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 103 if (!intAttr) { 104 parser.emitError(parser.getNameLoc(), 105 "expected an integral index bitwidth"); 106 return {}; 107 } 108 ind = intAttr.getInt(); 109 } else { 110 parser.emitError(parser.getNameLoc(), "unexpected key: ") 111 << attr.first.str(); 112 return {}; 113 } 114 } 115 // Construct struct-like storage for attribute. 116 return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, 117 map, ptr, ind); 118 } 119 120 void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const { 121 // Print the struct-like storage in dictionary fashion. 122 printer << "encoding<{ dimLevelType = [ "; 123 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 124 switch (getDimLevelType()[i]) { 125 case DimLevelType::Dense: 126 printer << "\"dense\""; 127 break; 128 case DimLevelType::Compressed: 129 printer << "\"compressed\""; 130 break; 131 case DimLevelType::Singleton: 132 printer << "\"singleton\""; 133 break; 134 } 135 if (i != e - 1) 136 printer << ", "; 137 } 138 printer << " ]"; 139 if (getDimOrdering()) 140 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 141 printer << ", pointerBitWidth = " << getPointerBitWidth() 142 << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 143 } 144 145 LogicalResult SparseTensorEncodingAttr::verify( 146 function_ref<InFlightDiagnostic()> emitError, 147 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 148 unsigned pointerBitWidth, unsigned indexBitWidth) { 149 if (!acceptBitWidth(pointerBitWidth)) 150 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 151 if (!acceptBitWidth(indexBitWidth)) 152 return emitError() << "unexpected index bitwidth: " << indexBitWidth; 153 if (dimOrdering) { 154 if (!dimOrdering.isPermutation()) 155 return emitError() 156 << "expected a permutation affine map for dimension ordering"; 157 if (dimOrdering.getNumResults() != dimLevelType.size()) 158 return emitError() << "unexpected mismatch in ordering and dimension " 159 "level types size"; 160 } 161 return success(); 162 } 163 164 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 165 ArrayRef<int64_t> shape, Type elementType, 166 function_ref<InFlightDiagnostic()> emitError) const { 167 // Check structural integrity. 168 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 169 getPointerBitWidth(), getIndexBitWidth()))) 170 return failure(); 171 // Check integrity with tensor type specifics. Dimension ordering is optional, 172 // but we always should have dimension level types for the full rank. 173 unsigned size = shape.size(); 174 if (getDimOrdering() && getDimOrdering().getNumResults() != size) 175 return emitError() << "expected an affine map of size " << size 176 << " for dimension ordering"; 177 if (getDimLevelType().size() != size) 178 return emitError() << "expected an array of size " << size 179 << " for dimension level types"; 180 return success(); 181 } 182 183 SparseTensorEncodingAttr 184 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 185 if (auto ttp = type.dyn_cast<RankedTensorType>()) 186 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 187 return nullptr; 188 } 189 190 //===----------------------------------------------------------------------===// 191 // TensorDialect Operations. 192 //===----------------------------------------------------------------------===// 193 194 static LogicalResult isInBounds(Value dim, Value tensor) { 195 if (auto constantOp = dim.getDefiningOp<arith::ConstantOp>()) { 196 unsigned d = constantOp.value().cast<IntegerAttr>().getInt(); 197 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 198 return failure(); 199 } 200 return success(); // in bounds, or symbolic 201 } 202 203 static LogicalResult isMatchingWidth(Value result, unsigned width) { 204 Type etp = result.getType().cast<MemRefType>().getElementType(); 205 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 206 return success(); 207 return failure(); 208 } 209 210 static LogicalResult verify(NewOp op) { 211 if (!getSparseTensorEncoding(op.result().getType())) 212 return op.emitError("expected a sparse tensor result"); 213 return success(); 214 } 215 216 static LogicalResult verify(ConvertOp op) { 217 if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) { 218 if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) { 219 assert(tp1.getRank() == tp2.getRank()); 220 auto shape1 = tp1.getShape(); 221 auto shape2 = tp2.getShape(); 222 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) { 223 if (shape1[d] != shape2[d]) 224 return op.emitError() 225 << "unexpected conversion mismatch in dimension " << d; 226 } 227 return success(); 228 } 229 } 230 return op.emitError("unexpected type in convert"); 231 } 232 233 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 234 if (getType() == source().getType()) 235 return source(); 236 return {}; 237 } 238 239 static LogicalResult verify(ReleaseOp op) { 240 if (!getSparseTensorEncoding(op.tensor().getType())) 241 return op.emitError("expected a sparse tensor to release"); 242 return success(); 243 } 244 245 static LogicalResult verify(ToPointersOp op) { 246 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 247 if (failed(isInBounds(op.dim(), op.tensor()))) 248 return op.emitError("requested pointers dimension out of bounds"); 249 if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) 250 return op.emitError("unexpected type for pointers"); 251 return success(); 252 } 253 return op.emitError("expected a sparse tensor to get pointers"); 254 } 255 256 static LogicalResult verify(ToIndicesOp op) { 257 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 258 if (failed(isInBounds(op.dim(), op.tensor()))) 259 return op.emitError("requested indices dimension out of bounds"); 260 if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) 261 return op.emitError("unexpected type for indices"); 262 return success(); 263 } 264 return op.emitError("expected a sparse tensor to get indices"); 265 } 266 267 static LogicalResult verify(ToValuesOp op) { 268 if (!getSparseTensorEncoding(op.tensor().getType())) 269 return op.emitError("expected a sparse tensor to get values"); 270 RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>(); 271 MemRefType mtp = op.result().getType().cast<MemRefType>(); 272 if (ttp.getElementType() != mtp.getElementType()) 273 return op.emitError("unexpected mismatch in element types"); 274 return success(); 275 } 276 277 static LogicalResult verify(ToTensorOp op) { 278 if (!getSparseTensorEncoding(op.result().getType())) 279 return op.emitError("expected a sparse tensor as result"); 280 return success(); 281 } 282 283 //===----------------------------------------------------------------------===// 284 // TensorDialect Methods. 285 //===----------------------------------------------------------------------===// 286 287 void SparseTensorDialect::initialize() { 288 addAttributes< 289 #define GET_ATTRDEF_LIST 290 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 291 >(); 292 addOperations< 293 #define GET_OP_LIST 294 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 295 >(); 296 } 297 298 #define GET_OP_CLASSES 299 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 300 301 Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser, 302 Type type) const { 303 StringRef attrTag; 304 if (failed(parser.parseKeyword(&attrTag))) 305 return Attribute(); 306 Attribute attr; 307 auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); 308 if (parseResult.hasValue()) 309 return attr; 310 parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute"); 311 return Attribute(); 312 } 313 314 void SparseTensorDialect::printAttribute(Attribute attr, 315 DialectAsmPrinter &printer) const { 316 if (succeeded(generatedAttributePrinter(attr, printer))) 317 return; 318 } 319