1319072f4SAart Bik //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// 2319072f4SAart Bik // 3319072f4SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4319072f4SAart Bik // See https://llvm.org/LICENSE.txt for license information. 5319072f4SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6319072f4SAart Bik // 7319072f4SAart Bik //===----------------------------------------------------------------------===// 8319072f4SAart Bik 9319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 1096a23911SAart Bik #include "mlir/Dialect/StandardOps/IR/Ops.h" 11319072f4SAart Bik #include "mlir/IR/Builders.h" 120a292199SAart Bik #include "mlir/IR/DialectImplementation.h" 13319072f4SAart Bik #include "mlir/IR/OpImplementation.h" 140a292199SAart Bik #include "llvm/ADT/TypeSwitch.h" 15319072f4SAart Bik 16319072f4SAart Bik using namespace mlir; 17319072f4SAart Bik using namespace mlir::sparse_tensor; 18319072f4SAart Bik 19485cc55eSStella Laurenzo #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" 20485cc55eSStella Laurenzo 210a292199SAart Bik //===----------------------------------------------------------------------===// 2296a23911SAart Bik // TensorDialect Attribute Methods. 230a292199SAart Bik //===----------------------------------------------------------------------===// 240a292199SAart Bik 250a292199SAart Bik #define GET_ATTRDEF_CLASSES 260a292199SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 270a292199SAart Bik 280a292199SAart Bik static bool acceptBitWidth(unsigned bitWidth) { 290a292199SAart Bik switch (bitWidth) { 300a292199SAart Bik case 0: 310a292199SAart Bik case 8: 320a292199SAart Bik case 16: 330a292199SAart Bik case 32: 340a292199SAart Bik case 64: 350a292199SAart Bik return true; 360a292199SAart Bik default: 370a292199SAart Bik return false; 380a292199SAart Bik } 390a292199SAart Bik } 400a292199SAart Bik 410a292199SAart Bik Attribute SparseTensorEncodingAttr::parse(MLIRContext *context, 420a292199SAart Bik DialectAsmParser &parser, Type type) { 430a292199SAart Bik if (failed(parser.parseLess())) 440a292199SAart Bik return {}; 450a292199SAart Bik // Parse the data as a dictionary. 460a292199SAart Bik DictionaryAttr dict; 470a292199SAart Bik if (failed(parser.parseAttribute(dict))) 480a292199SAart Bik return {}; 490a292199SAart Bik if (failed(parser.parseGreater())) 500a292199SAart Bik return {}; 510a292199SAart Bik // Process the data from the parsed dictionary value into struct-like data. 520a292199SAart Bik SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt; 530a292199SAart Bik AffineMap map = {}; 540a292199SAart Bik unsigned ptr = 0; 550a292199SAart Bik unsigned ind = 0; 560a292199SAart Bik for (const NamedAttribute &attr : dict) { 570a292199SAart Bik if (attr.first == "dimLevelType") { 580a292199SAart Bik auto arrayAttr = attr.second.dyn_cast<ArrayAttr>(); 590a292199SAart Bik if (!arrayAttr) { 600a292199SAart Bik parser.emitError(parser.getNameLoc(), 610a292199SAart Bik "expected an array for dimension level types"); 620a292199SAart Bik return {}; 630a292199SAart Bik } 640a292199SAart Bik for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) { 650a292199SAart Bik auto strAttr = arrayAttr[i].dyn_cast<StringAttr>(); 660a292199SAart Bik if (!strAttr) { 670a292199SAart Bik parser.emitError(parser.getNameLoc(), 680a292199SAart Bik "expected a string value in dimension level types"); 690a292199SAart Bik return {}; 700a292199SAart Bik } 710a292199SAart Bik auto strVal = strAttr.getValue(); 720a292199SAart Bik if (strVal == "dense") { 730a292199SAart Bik dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense); 740a292199SAart Bik } else if (strVal == "compressed") { 750a292199SAart Bik dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed); 760a292199SAart Bik } else if (strVal == "singleton") { 770a292199SAart Bik dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton); 780a292199SAart Bik } else { 790a292199SAart Bik parser.emitError(parser.getNameLoc(), 800a292199SAart Bik "unexpected dimension level type: ") 810a292199SAart Bik << strVal; 820a292199SAart Bik return {}; 830a292199SAart Bik } 840a292199SAart Bik } 850a292199SAart Bik } else if (attr.first == "dimOrdering") { 860a292199SAart Bik auto affineAttr = attr.second.dyn_cast<AffineMapAttr>(); 870a292199SAart Bik if (!affineAttr) { 880a292199SAart Bik parser.emitError(parser.getNameLoc(), 890a292199SAart Bik "expected an affine map for dimension ordering"); 900a292199SAart Bik return {}; 910a292199SAart Bik } 920a292199SAart Bik map = affineAttr.getValue(); 930a292199SAart Bik } else if (attr.first == "pointerBitWidth") { 940a292199SAart Bik auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 950a292199SAart Bik if (!intAttr) { 960a292199SAart Bik parser.emitError(parser.getNameLoc(), 970a292199SAart Bik "expected an integral pointer bitwidth"); 980a292199SAart Bik return {}; 990a292199SAart Bik } 1000a292199SAart Bik ptr = intAttr.getInt(); 1010a292199SAart Bik } else if (attr.first == "indexBitWidth") { 1020a292199SAart Bik auto intAttr = attr.second.dyn_cast<IntegerAttr>(); 1030a292199SAart Bik if (!intAttr) { 1040a292199SAart Bik parser.emitError(parser.getNameLoc(), 1050a292199SAart Bik "expected an integral index bitwidth"); 1060a292199SAart Bik return {}; 1070a292199SAart Bik } 1080a292199SAart Bik ind = intAttr.getInt(); 1090a292199SAart Bik } else { 1100a292199SAart Bik parser.emitError(parser.getNameLoc(), "unexpected key: ") 1110a292199SAart Bik << attr.first.str(); 1120a292199SAart Bik return {}; 1130a292199SAart Bik } 1140a292199SAart Bik } 1150a292199SAart Bik // Construct struct-like storage for attribute. 1160a292199SAart Bik return parser.getChecked<SparseTensorEncodingAttr>(context, dlt, map, ptr, 1170a292199SAart Bik ind); 1180a292199SAart Bik } 1190a292199SAart Bik 1200a292199SAart Bik void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const { 1210a292199SAart Bik // Print the struct-like storage in dictionary fashion. 1220a292199SAart Bik printer << "encoding<{ dimLevelType = [ "; 1230a292199SAart Bik for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) { 1240a292199SAart Bik switch (getDimLevelType()[i]) { 1250a292199SAart Bik case DimLevelType::Dense: 1260a292199SAart Bik printer << "\"dense\""; 1270a292199SAart Bik break; 1280a292199SAart Bik case DimLevelType::Compressed: 1290a292199SAart Bik printer << "\"compressed\""; 1300a292199SAart Bik break; 1310a292199SAart Bik case DimLevelType::Singleton: 1320a292199SAart Bik printer << "\"singleton\""; 1330a292199SAart Bik break; 1340a292199SAart Bik } 1350a292199SAart Bik if (i != e - 1) 1360a292199SAart Bik printer << ", "; 1370a292199SAart Bik } 1380a292199SAart Bik printer << " ]"; 1390a292199SAart Bik if (getDimOrdering()) 1400a292199SAart Bik printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">"; 1410a292199SAart Bik printer << ", pointerBitWidth = " << getPointerBitWidth() 1420a292199SAart Bik << ", indexBitWidth = " << getIndexBitWidth() << " }>"; 1430a292199SAart Bik } 1440a292199SAart Bik 1450a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verify( 1460a292199SAart Bik function_ref<InFlightDiagnostic()> emitError, 1470a292199SAart Bik ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering, 1480a292199SAart Bik unsigned pointerBitWidth, unsigned indexBitWidth) { 1490a292199SAart Bik if (!acceptBitWidth(pointerBitWidth)) 1500a292199SAart Bik return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth; 1510a292199SAart Bik if (!acceptBitWidth(indexBitWidth)) 1520a292199SAart Bik return emitError() << "unexpected index bitwidth: " << indexBitWidth; 1530a292199SAart Bik if (dimOrdering) { 1540a292199SAart Bik if (!dimOrdering.isPermutation()) 1550a292199SAart Bik return emitError() 1560a292199SAart Bik << "expected a permutation affine map for dimension ordering"; 1570a292199SAart Bik if (dimOrdering.getNumResults() != dimLevelType.size()) 1580a292199SAart Bik return emitError() << "unexpected mismatch in ordering and dimension " 1590a292199SAart Bik "level types size"; 1600a292199SAart Bik } 1610a292199SAart Bik return success(); 1620a292199SAart Bik } 1630a292199SAart Bik 1640a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verifyEncoding( 1650a292199SAart Bik ArrayRef<int64_t> shape, Type elementType, 1660a292199SAart Bik function_ref<InFlightDiagnostic()> emitError) const { 1670a292199SAart Bik // Check structural integrity. 1680a292199SAart Bik if (failed(verify(emitError, getDimLevelType(), getDimOrdering(), 1690a292199SAart Bik getPointerBitWidth(), getIndexBitWidth()))) 1700a292199SAart Bik return failure(); 1710a292199SAart Bik // Check integrity with tensor type specifics. Dimension ordering is optional, 1720a292199SAart Bik // but we always should have dimension level types for the full rank. 1730a292199SAart Bik unsigned size = shape.size(); 1740a292199SAart Bik if (getDimOrdering() && getDimOrdering().getNumResults() != size) 1750a292199SAart Bik return emitError() << "expected an affine map of size " << size 1760a292199SAart Bik << " for dimension ordering"; 1770a292199SAart Bik if (getDimLevelType().size() != size) 1780a292199SAart Bik return emitError() << "expected an array of size " << size 1790a292199SAart Bik << " for dimension level types"; 1800a292199SAart Bik return success(); 1810a292199SAart Bik } 1820a292199SAart Bik 18396a23911SAart Bik SparseTensorEncodingAttr 18496a23911SAart Bik mlir::sparse_tensor::getSparseTensorEncoding(Type type) { 18596a23911SAart Bik if (auto ttp = type.dyn_cast<RankedTensorType>()) 18696a23911SAart Bik return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>(); 18796a23911SAart Bik return nullptr; 18896a23911SAart Bik } 18996a23911SAart Bik 1900a292199SAart Bik //===----------------------------------------------------------------------===// 19196a23911SAart Bik // TensorDialect Operations. 19296a23911SAart Bik //===----------------------------------------------------------------------===// 19396a23911SAart Bik 19496a23911SAart Bik static LogicalResult isInBounds(Value dim, Value tensor) { 19596a23911SAart Bik if (auto constantOp = dim.getDefiningOp<ConstantOp>()) { 19696a23911SAart Bik unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt(); 19796a23911SAart Bik if (d >= tensor.getType().cast<RankedTensorType>().getRank()) 19896a23911SAart Bik return failure(); 19996a23911SAart Bik } 20096a23911SAart Bik return success(); // in bounds, or symbolic 20196a23911SAart Bik } 20296a23911SAart Bik 20396a23911SAart Bik static LogicalResult isMatchingWidth(Value result, unsigned width) { 20496a23911SAart Bik Type etp = result.getType().cast<MemRefType>().getElementType(); 20596a23911SAart Bik if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width))) 20696a23911SAart Bik return success(); 20796a23911SAart Bik return failure(); 20896a23911SAart Bik } 20996a23911SAart Bik 21096a23911SAart Bik static LogicalResult verify(NewOp op) { 211*697ea09dSAart Bik if (!getSparseTensorEncoding(op.result().getType())) 21296a23911SAart Bik return op.emitError("expected a sparse tensor result"); 21396a23911SAart Bik return success(); 21496a23911SAart Bik } 21596a23911SAart Bik 216*697ea09dSAart Bik static LogicalResult verify(ConvertOp op) { 217*697ea09dSAart Bik if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) { 218*697ea09dSAart Bik if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) { 219*697ea09dSAart Bik assert(tp1.getRank() == tp2.getRank()); 220*697ea09dSAart Bik auto shape1 = tp1.getShape(); 221*697ea09dSAart Bik auto shape2 = tp2.getShape(); 222*697ea09dSAart Bik for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) 223*697ea09dSAart Bik if (shape1[d] != shape2[d]) 224*697ea09dSAart Bik return op.emitError() 225*697ea09dSAart Bik << "unexpected conversion mismatch in dimension " << d; 226*697ea09dSAart Bik return success(); 227*697ea09dSAart Bik } 228*697ea09dSAart Bik } 229*697ea09dSAart Bik return op.emitError("unexpected type in convert"); 230*697ea09dSAart Bik } 231*697ea09dSAart Bik 23296a23911SAart Bik static LogicalResult verify(ToPointersOp op) { 233c2415d67SAart Bik if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 23496a23911SAart Bik if (failed(isInBounds(op.dim(), op.tensor()))) 23596a23911SAart Bik return op.emitError("requested pointers dimension out of bounds"); 23696a23911SAart Bik if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) 23796a23911SAart Bik return op.emitError("unexpected type for pointers"); 23896a23911SAart Bik return success(); 23996a23911SAart Bik } 24096a23911SAart Bik return op.emitError("expected a sparse tensor to get pointers"); 24196a23911SAart Bik } 24296a23911SAart Bik 24396a23911SAart Bik static LogicalResult verify(ToIndicesOp op) { 244c2415d67SAart Bik if (auto e = getSparseTensorEncoding(op.tensor().getType())) { 24596a23911SAart Bik if (failed(isInBounds(op.dim(), op.tensor()))) 24696a23911SAart Bik return op.emitError("requested indices dimension out of bounds"); 24796a23911SAart Bik if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) 24896a23911SAart Bik return op.emitError("unexpected type for indices"); 24996a23911SAart Bik return success(); 25096a23911SAart Bik } 25196a23911SAart Bik return op.emitError("expected a sparse tensor to get indices"); 25296a23911SAart Bik } 25396a23911SAart Bik 25496a23911SAart Bik static LogicalResult verify(ToValuesOp op) { 25596a23911SAart Bik if (!getSparseTensorEncoding(op.tensor().getType())) 25696a23911SAart Bik return op.emitError("expected a sparse tensor to get values"); 25796a23911SAart Bik RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>(); 25896a23911SAart Bik MemRefType mtp = op.result().getType().cast<MemRefType>(); 25996a23911SAart Bik if (ttp.getElementType() != mtp.getElementType()) 26096a23911SAart Bik return op.emitError("unexpected mismatch in element types"); 26196a23911SAart Bik return success(); 26296a23911SAart Bik } 26396a23911SAart Bik 264727a63e0SAart Bik static LogicalResult verify(ToTensorOp op) { 26536b66ab9SAart Bik if (!getSparseTensorEncoding(op.result().getType())) 266727a63e0SAart Bik return op.emitError("expected a sparse tensor as result"); 26736b66ab9SAart Bik return success(); 268727a63e0SAart Bik } 269727a63e0SAart Bik 27096a23911SAart Bik //===----------------------------------------------------------------------===// 27196a23911SAart Bik // TensorDialect Methods. 2720a292199SAart Bik //===----------------------------------------------------------------------===// 2730a292199SAart Bik 274319072f4SAart Bik void SparseTensorDialect::initialize() { 2750a292199SAart Bik addAttributes< 2760a292199SAart Bik #define GET_ATTRDEF_LIST 2770a292199SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" 2780a292199SAart Bik >(); 279319072f4SAart Bik addOperations< 280319072f4SAart Bik #define GET_OP_LIST 281319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 282319072f4SAart Bik >(); 283319072f4SAart Bik } 284319072f4SAart Bik 285319072f4SAart Bik #define GET_OP_CLASSES 286319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" 2870a292199SAart Bik 2880a292199SAart Bik Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser, 2890a292199SAart Bik Type type) const { 2900a292199SAart Bik StringRef attrTag; 2910a292199SAart Bik if (failed(parser.parseKeyword(&attrTag))) 2920a292199SAart Bik return Attribute(); 2930a292199SAart Bik Attribute attr; 2940a292199SAart Bik auto parseResult = 2950a292199SAart Bik generatedAttributeParser(getContext(), parser, attrTag, type, attr); 2960a292199SAart Bik if (parseResult.hasValue()) 2970a292199SAart Bik return attr; 2980a292199SAart Bik parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute"); 2990a292199SAart Bik return Attribute(); 3000a292199SAart Bik } 3010a292199SAart Bik 3020a292199SAart Bik void SparseTensorDialect::printAttribute(Attribute attr, 3030a292199SAart Bik DialectAsmPrinter &printer) const { 3040a292199SAart Bik if (succeeded(generatedAttributePrinter(attr, printer))) 3050a292199SAart Bik return; 3060a292199SAart Bik } 307