1319072f4SAart Bik //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2319072f4SAart Bik // 3319072f4SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4319072f4SAart Bik // See https://llvm.org/LICENSE.txt for license information. 5319072f4SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6319072f4SAart Bik // 7319072f4SAart Bik //===----------------------------------------------------------------------===// 8319072f4SAart Bik 9319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 10319072f4SAart Bik #include "mlir/IR/Builders.h" 110a292199SAart Bik #include "mlir/IR/DialectImplementation.h" 12*65e7cd13SRiver Riddle #include "mlir/IR/Matchers.h" 13319072f4SAart Bik #include "mlir/IR/OpImplementation.h" 140a292199SAart Bik #include "llvm/ADT/TypeSwitch.h" 15319072f4SAart Bik 16319072f4SAart Bik using namespace mlir; 17319072f4SAart Bik using namespace mlir::sparse_tensor; 18319072f4SAart Bik 19485cc55eSStella Laurenzo #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 20485cc55eSStella Laurenzo 210a292199SAart Bik //===----------------------------------------------------------------------===// 2296a23911SAart Bik // TensorDialect Attribute Methods. 230a292199SAart Bik //===----------------------------------------------------------------------===// 240a292199SAart Bik 250a292199SAart Bik #define GET_ATTRDEF_CLASSES 260a292199SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 270a292199SAart Bik 280a292199SAart Bik static bool acceptBitWidth(unsigned bitWidth) { 290a292199SAart Bik switch (bitWidth) { 300a292199SAart Bik case 0: 310a292199SAart Bik case 8: 320a292199SAart Bik case 16: 330a292199SAart Bik case 32: 340a292199SAart Bik case 64: 350a292199SAart Bik return true; 360a292199SAart Bik default: 370a292199SAart Bik return false; 380a292199SAart Bik } 390a292199SAart Bik } 400a292199SAart Bik 41f97e72aaSMehdi Amini Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { 420a292199SAart Bik if (failed(parser.parseLess())) 430a292199SAart Bik return {}; 440a292199SAart Bik // Parse the data as a dictionary. 450a292199SAart Bik DictionaryAttr dict; 460a292199SAart Bik if (failed(parser.parseAttribute(dict))) 470a292199SAart Bik return {}; 480a292199SAart Bik if (failed(parser.parseGreater())) 490a292199SAart Bik return {}; 500a292199SAart Bik // Process the data from the parsed dictionary value into struct-like data. 510a292199SAart Bik SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 520a292199SAart Bik AffineMap map = {}; 530a292199SAart Bik unsigned ptr = 0; 540a292199SAart Bik unsigned ind = 0; 550a292199SAart Bik for (const NamedAttribute &attr : dict) { 560c7890c8SRiver Riddle if (attr.getName() == "dimLevelType") { 570c7890c8SRiver Riddle auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>(); 580a292199SAart Bik if (!arrayAttr) { 590a292199SAart Bik parser.emitError(parser.getNameLoc(), 600a292199SAart Bik "expected an array for dimension level types"); 610a292199SAart Bik return {}; 620a292199SAart Bik } 63e5639b3fSMehdi Amini for (auto i : arrayAttr) { 64e5639b3fSMehdi Amini auto strAttr = i.dyn_cast<StringAttr>(); 650a292199SAart Bik if (!strAttr) { 660a292199SAart Bik parser.emitError(parser.getNameLoc(), 670a292199SAart Bik "expected a string value in dimension level types"); 680a292199SAart Bik return {}; 690a292199SAart Bik } 700a292199SAart Bik auto strVal = strAttr.getValue(); 710a292199SAart Bik if (strVal == "dense") { 720a292199SAart Bik dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 730a292199SAart Bik } else if (strVal == "compressed") { 740a292199SAart Bik dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 750a292199SAart Bik } else if (strVal == "singleton") { 760a292199SAart Bik dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 770a292199SAart Bik } else { 780a292199SAart Bik parser.emitError(parser.getNameLoc(), 790a292199SAart Bik "unexpected dimension level type: ") 800a292199SAart Bik << strVal; 810a292199SAart Bik return {}; 820a292199SAart Bik } 830a292199SAart Bik } 840c7890c8SRiver Riddle } else if (attr.getName() == "dimOrdering") { 850c7890c8SRiver Riddle auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>(); 860a292199SAart Bik if (!affineAttr) { 870a292199SAart Bik parser.emitError(parser.getNameLoc(), 880a292199SAart Bik "expected an affine map for dimension ordering"); 890a292199SAart Bik return {}; 900a292199SAart Bik } 910a292199SAart Bik map = affineAttr.getValue(); 920c7890c8SRiver Riddle } else if (attr.getName() == "pointerBitWidth") { 930c7890c8SRiver Riddle auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 940a292199SAart Bik if (!intAttr) { 950a292199SAart Bik parser.emitError(parser.getNameLoc(), 960a292199SAart Bik "expected an integral pointer bitwidth"); 970a292199SAart Bik return {}; 980a292199SAart Bik } 990a292199SAart Bik ptr = intAttr.getInt(); 1000c7890c8SRiver Riddle } else if (attr.getName() == "indexBitWidth") { 1010c7890c8SRiver Riddle auto intAttr = attr.getValue().dyn_cast<IntegerAttr>(); 1020a292199SAart Bik if (!intAttr) { 1030a292199SAart Bik parser.emitError(parser.getNameLoc(), 1040a292199SAart Bik "expected an integral index bitwidth"); 1050a292199SAart Bik return {}; 1060a292199SAart Bik } 1070a292199SAart Bik ind = intAttr.getInt(); 1080a292199SAart Bik } else { 1090a292199SAart Bik parser.emitError(parser.getNameLoc(), "unexpected key: ") 1100c7890c8SRiver Riddle << attr.getName().strref(); 1110a292199SAart Bik return {}; 1120a292199SAart Bik } 1130a292199SAart Bik } 1140a292199SAart Bik // Construct struct-like storage for attribute. 115fb093c83SChris Lattner return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt, 116fb093c83SChris Lattner map, ptr, ind); 1170a292199SAart Bik } 1180a292199SAart Bik 119f97e72aaSMehdi Amini void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { 1200a292199SAart Bik // Print the struct-like storage in dictionary fashion. 121f30a8a6fSMehdi Amini printer << "<{ dimLevelType = [ "; 1220a292199SAart Bik for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 1230a292199SAart Bik switch (getDimLevelType()[i]) { 1240a292199SAart Bik case DimLevelType::Dense: 1250a292199SAart Bik printer << "\"dense\""; 1260a292199SAart Bik break; 1270a292199SAart Bik case DimLevelType::Compressed: 1280a292199SAart Bik printer << "\"compressed\""; 1290a292199SAart Bik break; 1300a292199SAart Bik case DimLevelType::Singleton: 1310a292199SAart Bik printer << "\"singleton\""; 1320a292199SAart Bik break; 1330a292199SAart Bik } 1340a292199SAart Bik if (i != e - 1) 1350a292199SAart Bik printer << ", "; 1360a292199SAart Bik } 1370a292199SAart Bik printer << " ]"; 1380a292199SAart Bik if (getDimOrdering()) 1390a292199SAart Bik printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 1400a292199SAart Bik printer << ", pointerBitWidth = " << getPointerBitWidth() 1410a292199SAart Bik << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 1420a292199SAart Bik } 1430a292199SAart Bik 1440a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verify( 1450a292199SAart Bik function_ref<InFlightDiagnostic()> emitError, 1460a292199SAart Bik ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 1470a292199SAart Bik unsigned pointerBitWidth, unsigned indexBitWidth) { 1480a292199SAart Bik if (!acceptBitWidth(pointerBitWidth)) 1490a292199SAart Bik return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 1500a292199SAart Bik if (!acceptBitWidth(indexBitWidth)) 1510a292199SAart Bik return emitError() << "unexpected index bitwidth: " << indexBitWidth; 1520a292199SAart Bik if (dimOrdering) { 1530a292199SAart Bik if (!dimOrdering.isPermutation()) 1540a292199SAart Bik return emitError() 1550a292199SAart Bik << "expected a permutation affine map for dimension ordering"; 1560a292199SAart Bik if (dimOrdering.getNumResults() != dimLevelType.size()) 1570a292199SAart Bik return emitError() << "unexpected mismatch in ordering and dimension " 1580a292199SAart Bik "level types size"; 1590a292199SAart Bik } 1600a292199SAart Bik return success(); 1610a292199SAart Bik } 1620a292199SAart Bik 1630a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verifyEncoding( 1640a292199SAart Bik ArrayRef<int64_t> shape, Type elementType, 1650a292199SAart Bik function_ref<InFlightDiagnostic()> emitError) const { 1660a292199SAart Bik // Check structural integrity. 1670a292199SAart Bik if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 1680a292199SAart Bik getPointerBitWidth(), getIndexBitWidth()))) 1690a292199SAart Bik return failure(); 1700a292199SAart Bik // Check integrity with tensor type specifics. Dimension ordering is optional, 1710a292199SAart Bik // but we always should have dimension level types for the full rank. 1720a292199SAart Bik unsigned size = shape.size(); 1734aa9b398SAart Bik if (size == 0) 1744aa9b398SAart Bik return emitError() << "expected non-scalar sparse tensor"; 1750a292199SAart Bik if (getDimOrdering() && getDimOrdering().getNumResults() != size) 1760a292199SAart Bik return emitError() << "expected an affine map of size " << size 1770a292199SAart Bik << " for dimension ordering"; 1780a292199SAart Bik if (getDimLevelType().size() != size) 1790a292199SAart Bik return emitError() << "expected an array of size " << size 1800a292199SAart Bik << " for dimension level types"; 1810a292199SAart Bik return success(); 1820a292199SAart Bik } 1830a292199SAart Bik 18496a23911SAart Bik SparseTensorEncodingAttr 18596a23911SAart Bik mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 18696a23911SAart Bik if (auto ttp = type.dyn_cast<RankedTensorType>()) 18796a23911SAart Bik return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 18896a23911SAart Bik return nullptr; 18996a23911SAart Bik } 19096a23911SAart Bik 1910a292199SAart Bik //===----------------------------------------------------------------------===// 19296a23911SAart Bik // TensorDialect Operations. 19396a23911SAart Bik //===----------------------------------------------------------------------===// 19496a23911SAart Bik 19596a23911SAart Bik static LogicalResult isInBounds(Value dim, Value tensor) { 196*65e7cd13SRiver Riddle IntegerAttr constantAttr; 197*65e7cd13SRiver Riddle if (matchPattern(dim, m_Constant(&constantAttr))) { 198*65e7cd13SRiver Riddle unsigned d = constantAttr.getInt(); 19996a23911SAart Bik if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 20096a23911SAart Bik return failure(); 20196a23911SAart Bik } 20296a23911SAart Bik return success(); // in bounds, or symbolic 20396a23911SAart Bik } 20496a23911SAart Bik 20596a23911SAart Bik static LogicalResult isMatchingWidth(Value result, unsigned width) { 20696a23911SAart Bik Type etp = result.getType().cast<MemRefType>().getElementType(); 20796a23911SAart Bik if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 20896a23911SAart Bik return success(); 20996a23911SAart Bik return failure(); 21096a23911SAart Bik } 21196a23911SAart Bik 21296a23911SAart Bik static LogicalResult verify(NewOp op) { 213697ea09dSAart Bik if (!getSparseTensorEncoding(op.result().getType())) 21496a23911SAart Bik return op.emitError("expected a sparse tensor result"); 21596a23911SAart Bik return success(); 21696a23911SAart Bik } 21796a23911SAart Bik 21835517a25SAart Bik static LogicalResult verify(InitOp op) { 21935517a25SAart Bik if (!getSparseTensorEncoding(op.result().getType())) 22035517a25SAart Bik return op.emitError("expected a sparse tensor result"); 22135517a25SAart Bik RankedTensorType ttp = op.getType().cast<RankedTensorType>(); 22235517a25SAart Bik unsigned rank = ttp.getRank(); 22335517a25SAart Bik if (rank != op.sizes().size()) 22435517a25SAart Bik return op.emitError("unexpected mismatch between tensor rank and sizes: ") 22535517a25SAart Bik << rank << " vs. " << op.sizes().size(); 22635517a25SAart Bik auto shape = ttp.getShape(); 22735517a25SAart Bik for (unsigned i = 0; i < rank; i++) { 22835517a25SAart Bik if (shape[i] == ShapedType::kDynamicSize) 22935517a25SAart Bik continue; 230*65e7cd13SRiver Riddle IntegerAttr constantAttr; 231*65e7cd13SRiver Riddle if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) || 232*65e7cd13SRiver Riddle constantAttr.getInt() != shape[i]) { 23335517a25SAart Bik return op.emitError("unexpected mismatch with static dimension size ") 23435517a25SAart Bik << shape[i]; 23535517a25SAart Bik } 236*65e7cd13SRiver Riddle } 23735517a25SAart Bik return success(); 23835517a25SAart Bik } 23935517a25SAart Bik 240697ea09dSAart Bik static LogicalResult verify(ConvertOp op) { 241697ea09dSAart Bik if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) { 242697ea09dSAart Bik if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) { 2431e6ef0cfSAart Bik if (tp1.getRank() != tp2.getRank()) 2441e6ef0cfSAart Bik return op.emitError("unexpected conversion mismatch in rank"); 245697ea09dSAart Bik auto shape1 = tp1.getShape(); 246697ea09dSAart Bik auto shape2 = tp2.getShape(); 2479d1db3d4SAart Bik // Accept size matches between the source and the destination type 2489d1db3d4SAart Bik // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or 2499d1db3d4SAart Bik // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). 25005c7f450SAart Bik for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) { 2519d1db3d4SAart Bik if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) 25235517a25SAart Bik return op.emitError("unexpected conversion mismatch in dimension ") 25335517a25SAart Bik << d; 25405c7f450SAart Bik } 255697ea09dSAart Bik return success(); 256697ea09dSAart Bik } 257697ea09dSAart Bik } 258697ea09dSAart Bik return op.emitError("unexpected type in convert"); 259697ea09dSAart Bik } 260697ea09dSAart Bik 261066d786cSAart Bik OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) { 262066d786cSAart Bik if (getType() == source().getType()) 263066d786cSAart Bik return source(); 264066d786cSAart Bik return {}; 265066d786cSAart Bik } 266066d786cSAart Bik 26796a23911SAart Bik static LogicalResult verify(ToPointersOp op) { 268c2415d67SAart Bik if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 26996a23911SAart Bik if (failed(isInBounds(op.dim(), op.tensor()))) 27096a23911SAart Bik return op.emitError("requested pointers dimension out of bounds"); 27196a23911SAart Bik if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) 27296a23911SAart Bik return op.emitError("unexpected type for pointers"); 27396a23911SAart Bik return success(); 27496a23911SAart Bik } 27596a23911SAart Bik return op.emitError("expected a sparse tensor to get pointers"); 27696a23911SAart Bik } 27796a23911SAart Bik 27896a23911SAart Bik static LogicalResult verify(ToIndicesOp op) { 279c2415d67SAart Bik if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 28096a23911SAart Bik if (failed(isInBounds(op.dim(), op.tensor()))) 28196a23911SAart Bik return op.emitError("requested indices dimension out of bounds"); 28296a23911SAart Bik if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) 28396a23911SAart Bik return op.emitError("unexpected type for indices"); 28496a23911SAart Bik return success(); 28596a23911SAart Bik } 28696a23911SAart Bik return op.emitError("expected a sparse tensor to get indices"); 28796a23911SAart Bik } 28896a23911SAart Bik 28996a23911SAart Bik static LogicalResult verify(ToValuesOp op) { 29096a23911SAart Bik if (!getSparseTensorEncoding(op.tensor().getType())) 29196a23911SAart Bik return op.emitError("expected a sparse tensor to get values"); 29296a23911SAart Bik RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>(); 29396a23911SAart Bik MemRefType mtp = op.result().getType().cast<MemRefType>(); 29496a23911SAart Bik if (ttp.getElementType() != mtp.getElementType()) 29596a23911SAart Bik return op.emitError("unexpected mismatch in element types"); 29696a23911SAart Bik return success(); 29796a23911SAart Bik } 29896a23911SAart Bik 299f66e5769SAart Bik //===----------------------------------------------------------------------===// 300f66e5769SAart Bik // TensorDialect Management Operations. 301f66e5769SAart Bik //===----------------------------------------------------------------------===// 302f66e5769SAart Bik 303f66e5769SAart Bik static LogicalResult verify(LexInsertOp op) { 304f66e5769SAart Bik if (!getSparseTensorEncoding(op.tensor().getType())) 305f66e5769SAart Bik return op.emitError("expected a sparse tensor for insertion"); 306f66e5769SAart Bik return success(); 307f66e5769SAart Bik } 308f66e5769SAart Bik 3094f2ec7f9SAart Bik static LogicalResult verify(ExpandOp op) { 3104f2ec7f9SAart Bik if (!getSparseTensorEncoding(op.tensor().getType())) 3114f2ec7f9SAart Bik return op.emitError("expected a sparse tensor for expansion"); 3124f2ec7f9SAart Bik return success(); 3134f2ec7f9SAart Bik } 3144f2ec7f9SAart Bik 3154f2ec7f9SAart Bik static LogicalResult verify(CompressOp op) { 3164f2ec7f9SAart Bik if (!getSparseTensorEncoding(op.tensor().getType())) 3174f2ec7f9SAart Bik return op.emitError("expected a sparse tensor for compression"); 3184f2ec7f9SAart Bik return success(); 3194f2ec7f9SAart Bik } 3204f2ec7f9SAart Bik 321f66e5769SAart Bik static LogicalResult verify(LoadOp op) { 322f66e5769SAart Bik if (!getSparseTensorEncoding(op.tensor().getType())) 323f66e5769SAart Bik return op.emitError("expected a sparse tensor to materialize"); 324f66e5769SAart Bik return success(); 325f66e5769SAart Bik } 326f66e5769SAart Bik 327f66e5769SAart Bik static LogicalResult verify(ReleaseOp op) { 328f66e5769SAart Bik if (!getSparseTensorEncoding(op.tensor().getType())) 329f66e5769SAart Bik return op.emitError("expected a sparse tensor to release"); 33036b66ab9SAart Bik return success(); 331727a63e0SAart Bik } 332727a63e0SAart Bik 333efa15f41SAart Bik static LogicalResult verify(OutOp op) { 334efa15f41SAart Bik if (!getSparseTensorEncoding(op.tensor().getType())) 335efa15f41SAart Bik return op.emitError("expected a sparse tensor for output"); 336efa15f41SAart Bik return success(); 337efa15f41SAart Bik } 338efa15f41SAart Bik 33996a23911SAart Bik //===----------------------------------------------------------------------===// 34096a23911SAart Bik // TensorDialect Methods. 3410a292199SAart Bik //===----------------------------------------------------------------------===// 3420a292199SAart Bik 343319072f4SAart Bik void SparseTensorDialect::initialize() { 3440a292199SAart Bik addAttributes< 3450a292199SAart Bik #define GET_ATTRDEF_LIST 3460a292199SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 3470a292199SAart Bik >(); 348319072f4SAart Bik addOperations< 349319072f4SAart Bik #define GET_OP_LIST 350319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 351319072f4SAart Bik >(); 352319072f4SAart Bik } 353319072f4SAart Bik 354319072f4SAart Bik #define GET_OP_CLASSES 355319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 356