1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 10 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/DialectImplementation.h" 14 #include "mlir/IR/OpImplementation.h" 15 #include "llvm/ADT/TypeSwitch.h" 16 17 using namespace mlir; 18 using namespace mlir::sparse_tensor; 19 20 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 21 22 //===----------------------------------------------------------------------===// 23 // TensorDialect Attribute Methods. 24 //===----------------------------------------------------------------------===// 25 26 #define GET_ATTRDEF_CLASSES 27 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 28 29 static bool acceptBitWidth(unsigned bitWidth) { 30 switch (bitWidth) { 31 case 0: 32 case 8: 33 case 16: 34 case 32: 35 case 64: 36 return true; 37 default: 38 return false; 39 } 40 } 41 42 Attribute SparseTensorEncodingAttr::parse(DialectAsmParser &parser, Type type) { 43 if (failed(parser.parseLess())) 44 return {}; 45 // Parse the data as a dictionary. 46 DictionaryAttr dict; 47 if (failed(parser.parseAttribute(dict))) 48 return {}; 49 if (failed(parser.parseGreater())) 50 return {}; 51 // Process the data from the parsed dictionary value into struct-like data. 52 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 53 AffineMap map = {}; 54 unsigned ptr = 0; 55 unsigned ind = 0; 56 for (const NamedAttribute &attr : dict) { 57 if (attr.first == "dimLevelType") { 58 auto arrayAttr = attr.second.dyn_cast<ArrayAttr>(); 59 if (!arrayAttr) { 60 parser.emitError(parser.getNameLoc(), 61 "expected an array for dimension level types"); 62 return {}; 63 } 64 for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) { 65 auto strAttr = arrayAttr[i].dyn_cast<StringAttr>(); 66 if (!strAttr) { 67 parser.emitError(parser.getNameLoc(), 68 "expected a string value in dimension level types"); 69 return {}; 70 } 71 auto strVal = strAttr.getValue(); 72 if (strVal == "dense") { 73 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 74 } else if (strVal == "compressed") { 75 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 76 } else if (strVal == "singleton") { 77 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 78 } else { 79 parser.emitError(parser.getNameLoc(), 80 "unexpected dimension level type: ") 81 << strVal; 82 return {}; 83 } 84 } 85 } else if (attr.first == "dimOrdering") { 86 auto affineAttr = attr.second.dyn_cast<AffineMapAttr>(); 87 if (!affineAttr) { 88 parser.emitError(parser.getNameLoc(), 89 "expected an affine map for dimension ordering"); 90 return {}; 91 } 92 map = affineAttr.getValue(); 93 } else if (attr.first == "pointerBitWidth") { 94 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 95 if (!intAttr) { 96 parser.emitError(parser.getNameLoc(), 97 "expected an integral pointer bitwidth"); 98 return {}; 99 } 100 ptr = intAttr.getInt(); 101 } else if (attr.first == "indexBitWidth") { 102 auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 103 if (!intAttr) { 104 parser.emitError(parser.getNameLoc(), 105 "expected an integral index bitwidth"); 106 return {}; 107 } 108 ind = intAttr.getInt(); 109 } else { 110 parser.emitError(parser.getNameLoc(), "unexpected key: ") 111 << attr.first.str(); 112 return {}; 113 } 114 } 115 // Construct struct-like storage for attribute. 116 return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, 117 map, ptr, ind); 118 } 119 120 void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const { 121 // Print the struct-like storage in dictionary fashion. 122 printer << "encoding<{ dimLevelType = [ "; 123 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 124 switch (getDimLevelType()[i]) { 125 case DimLevelType::Dense: 126 printer << "\"dense\""; 127 break; 128 case DimLevelType::Compressed: 129 printer << "\"compressed\""; 130 break; 131 case DimLevelType::Singleton: 132 printer << "\"singleton\""; 133 break; 134 } 135 if (i != e - 1) 136 printer << ", "; 137 } 138 printer << " ]"; 139 if (getDimOrdering()) 140 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 141 printer << ", pointerBitWidth = " << getPointerBitWidth() 142 << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 143 } 144 145 LogicalResult SparseTensorEncodingAttr::verify( 146 function_ref<InFlightDiagnostic()> emitError, 147 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 148 unsigned pointerBitWidth, unsigned indexBitWidth) { 149 if (!acceptBitWidth(pointerBitWidth)) 150 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 151 if (!acceptBitWidth(indexBitWidth)) 152 return emitError() << "unexpected index bitwidth: " << indexBitWidth; 153 if (dimOrdering) { 154 if (!dimOrdering.isPermutation()) 155 return emitError() 156 << "expected a permutation affine map for dimension ordering"; 157 if (dimOrdering.getNumResults() != dimLevelType.size()) 158 return emitError() << "unexpected mismatch in ordering and dimension " 159 "level types size"; 160 } 161 return success(); 162 } 163 164 LogicalResult SparseTensorEncodingAttr::verifyEncoding( 165 ArrayRef<int64_t> shape, Type elementType, 166 function_ref<InFlightDiagnostic()> emitError) const { 167 // Check structural integrity. 168 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 169 getPointerBitWidth(), getIndexBitWidth()))) 170 return failure(); 171 // Check integrity with tensor type specifics. Dimension ordering is optional, 172 // but we always should have dimension level types for the full rank. 173 unsigned size = shape.size(); 174 if (getDimOrdering() && getDimOrdering().getNumResults() != size) 175 return emitError() << "expected an affine map of size " << size 176 << " for dimension ordering"; 177 if (getDimLevelType().size() != size) 178 return emitError() << "expected an array of size " << size 179 << " for dimension level types"; 180 return success(); 181 } 182 183 SparseTensorEncodingAttr 184 mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 185 if (auto ttp = type.dyn_cast<RankedTensorType>()) 186 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 187 return nullptr; 188 } 189 190 //===----------------------------------------------------------------------===// 191 // TensorDialect Operations. 192 //===----------------------------------------------------------------------===// 193 194 static LogicalResult isInBounds(Value dim, Value tensor) { 195 if (auto constantOp = dim.getDefiningOp<arith::ConstantOp>()) { 196 unsigned d = constantOp.value().cast<IntegerAttr>().getInt(); 197 if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 198 return failure(); 199 } 200 return success(); // in bounds, or symbolic 201 } 202 203 static LogicalResult isMatchingWidth(Value result, unsigned width) { 204 Type etp = result.getType().cast<MemRefType>().getElementType(); 205 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 206 return success(); 207 return failure(); 208 } 209 210 static LogicalResult verify(NewOp op) { 211 if (!getSparseTensorEncoding(op.result().getType())) 212 return op.emitError("expected a sparse tensor result"); 213 return success(); 214 } 215 216 static LogicalResult verify(InitOp op) { 217 if (!getSparseTensorEncoding(op.result().getType())) 218 return op.emitError("expected a sparse tensor result"); 219 RankedTensorType ttp = op.getType().cast<RankedTensorType>(); 220 unsigned rank = ttp.getRank(); 221 if (rank != op.sizes().size()) 222 return op.emitError("unexpected mismatch between tensor rank and sizes: ") 223 << rank << " vs. " << op.sizes().size(); 224 auto shape = ttp.getShape(); 225 for (unsigned i = 0; i < rank; i++) { 226 if (shape[i] == ShapedType::kDynamicSize) 227 continue; 228 auto constantOp = op.sizes()[i].getDefiningOp<arith::ConstantOp>(); 229 if (!constantOp || 230 constantOp.value().cast<IntegerAttr>().getInt() != shape[i]) 231 return op.emitError("unexpected mismatch with static dimension size ") 232 << shape[i]; 233 } 234 return success(); 235 } 236 237 static LogicalResult verify(ConvertOp op) { 238 if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) { 239 if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) { 240 assert(tp1.getRank() == tp2.getRank()); 241 auto shape1 = tp1.getShape(); 242 auto shape2 = tp2.getShape(); 243 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) { 244 if (shape1[d] != shape2[d]) 245 return op.emitError("unexpected conversion mismatch in dimension ") 246 << d; 247 } 248 return success(); 249 } 250 } 251 return op.emitError("unexpected type in convert"); 252 } 253 254 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 255 if (getType() == source().getType()) 256 return source(); 257 return {}; 258 } 259 260 static LogicalResult verify(ReleaseOp op) { 261 if (!getSparseTensorEncoding(op.tensor().getType())) 262 return op.emitError("expected a sparse tensor to release"); 263 return success(); 264 } 265 266 static LogicalResult verify(ToPointersOp op) { 267 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 268 if (failed(isInBounds(op.dim(), op.tensor()))) 269 return op.emitError("requested pointers dimension out of bounds"); 270 if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) 271 return op.emitError("unexpected type for pointers"); 272 return success(); 273 } 274 return op.emitError("expected a sparse tensor to get pointers"); 275 } 276 277 static LogicalResult verify(ToIndicesOp op) { 278 if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 279 if (failed(isInBounds(op.dim(), op.tensor()))) 280 return op.emitError("requested indices dimension out of bounds"); 281 if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) 282 return op.emitError("unexpected type for indices"); 283 return success(); 284 } 285 return op.emitError("expected a sparse tensor to get indices"); 286 } 287 288 static LogicalResult verify(ToValuesOp op) { 289 if (!getSparseTensorEncoding(op.tensor().getType())) 290 return op.emitError("expected a sparse tensor to get values"); 291 RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>(); 292 MemRefType mtp = op.result().getType().cast<MemRefType>(); 293 if (ttp.getElementType() != mtp.getElementType()) 294 return op.emitError("unexpected mismatch in element types"); 295 return success(); 296 } 297 298 static LogicalResult verify(ToTensorOp op) { 299 if (!getSparseTensorEncoding(op.result().getType())) 300 return op.emitError("expected a sparse tensor result"); 301 return success(); 302 } 303 304 //===----------------------------------------------------------------------===// 305 // TensorDialect Methods. 306 //===----------------------------------------------------------------------===// 307 308 void SparseTensorDialect::initialize() { 309 addAttributes< 310 #define GET_ATTRDEF_LIST 311 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 312 >(); 313 addOperations< 314 #define GET_OP_LIST 315 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 316 >(); 317 } 318 319 #define GET_OP_CLASSES 320 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 321 322 Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser, 323 Type type) const { 324 StringRef attrTag; 325 if (failed(parser.parseKeyword(&attrTag))) 326 return Attribute(); 327 Attribute attr; 328 auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); 329 if (parseResult.hasValue()) 330 return attr; 331 parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute"); 332 return Attribute(); 333 } 334 335 void SparseTensorDialect::printAttribute(Attribute attr, 336 DialectAsmPrinter &printer) const { 337 if (succeeded(generatedAttributePrinter(attr, printer))) 338 return; 339 } 340