//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"

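// Returns true if the given bit width is one of the widths accepted for
// pointer and index overhead storage, where width 0 denotes the native
// `index` width.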
static bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

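// Parses a sparse tensor encoding attribute from its textual form: the body
// between the angle brackets is a plain attribute dictionary, for example
// (the `#sparse_tensor.encoding` mnemonic is shown only for illustration):
//
//   #sparse_tensor.encoding<{
//     dimLevelType = [ "dense", "compressed" ],
//     dimOrdering = affine_map<(i, j) -> (j, i)>,
//     pointerBitWidth = 64,
//     indexBitWidth = 64
//   }>
//
// Entries may appear in any order; unknown keys are rejected with an error.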
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};
  unsigned ptr = 0;
  unsigned ind = 0;
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      ind = intAttr.getInt();
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}

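// Prints the encoding in the same dictionary form accepted by parse() above,
// so the attribute round-trips through the textual IR.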
void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  // Print the struct-like storage in dictionary fashion.
  printer << "<{ dimLevelType = [ ";
  for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
    switch (getDimLevelType()[i]) {
    case DimLevelType::Dense:
      printer << "\"dense\"";
      break;
    case DimLevelType::Compressed:
      printer << "\"compressed\"";
      break;
    case DimLevelType::Singleton:
      printer << "\"singleton\"";
      break;
    }
    if (i != e - 1)
      printer << ", ";
  }
  printer << " ]";
  if (getDimOrdering())
    printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
  printer << ", pointerBitWidth = " << getPointerBitWidth()
          << ", indexBitWidth = " << getIndexBitWidth() << " }>";
}

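// Verifies the structural invariants of the encoding itself: accepted
// overhead bit widths and, when present, a permutation dimension ordering
// whose number of results matches the number of dimension level types.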
LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
    unsigned pointerBitWidth, unsigned indexBitWidth) {
  if (!acceptBitWidth(pointerBitWidth))
    return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
  if (!acceptBitWidth(indexBitWidth))
    return emitError() << "unexpected index bitwidth: " << indexBitWidth;
  if (dimOrdering) {
    if (!dimOrdering.isPermutation())
      return emitError()
             << "expected a permutation affine map for dimension ordering";
    if (dimOrdering.getNumResults() != dimLevelType.size())
      return emitError() << "unexpected mismatch in ordering and dimension "
                            "level types size";
  }
  return success();
}

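// Verifies the encoding against the ranked tensor type it annotates: the
// dimension level types (and the optional ordering) must cover the full rank.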
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<int64_t> shape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.
  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
                    getPointerBitWidth(), getIndexBitWidth())))
    return failure();
  // Check integrity with tensor type specifics. Dimension ordering is
  // optional, but we should always have dimension level types for the full
  // rank.
  unsigned size = shape.size();
  if (size == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimOrdering() && getDimOrdering().getNumResults() != size)
    return emitError() << "expected an affine map of size " << size
                       << " for dimension ordering";
  if (getDimLevelType().size() != size)
    return emitError() << "expected an array of size " << size
                       << " for dimension level types";
  return success();
}

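// Returns the sparse tensor encoding of the given type, or a null attribute
// if the type is not a ranked tensor or carries no such encoding.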
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = type.dyn_cast<RankedTensorType>())
    return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//

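// Returns failure only when `dim` is a compile-time constant that exceeds
// the rank of `tensor`; non-constant dimensions are accepted as is.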
static LogicalResult isInBounds(Value dim, Value tensor) {
  IntegerAttr constantAttr;
  if (matchPattern(dim, m_Constant(&constantAttr))) {
    unsigned d = constantAttr.getInt();
    if (d >= tensor.getType().cast<RankedTensorType>().getRank())
      return failure();
  }
  return success(); // in bounds, or symbolic
}

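// Returns success if the element type of the memref `result` matches the
// requested overhead bit width, where width 0 denotes the native index type.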
static LogicalResult isMatchingWidth(Value result, unsigned width) {
  Type etp = result.getType().cast<MemRefType>().getElementType();
  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
    return success();
  return failure();
}

static LogicalResult verify(NewOp op) {
  if (!getSparseTensorEncoding(op.result().getType()))
    return op.emitError("expected a sparse tensor result");
  return success();
}

static LogicalResult verify(InitOp op) {
  if (!getSparseTensorEncoding(op.result().getType()))
    return op.emitError("expected a sparse tensor result");
  RankedTensorType ttp = op.getType().cast<RankedTensorType>();
  unsigned rank = ttp.getRank();
  if (rank != op.sizes().size())
    return op.emitError("unexpected mismatch between tensor rank and sizes: ")
           << rank << " vs. " << op.sizes().size();
  auto shape = ttp.getShape();
  for (unsigned i = 0; i < rank; i++) {
    if (shape[i] == ShapedType::kDynamicSize)
      continue;
    IntegerAttr constantAttr;
    if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) ||
        constantAttr.getInt() != shape[i]) {
      return op.emitError("unexpected mismatch with static dimension size ")
             << shape[i];
    }
  }
  return success();
}

static LogicalResult verify(ConvertOp op) {
  if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
    if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
      if (tp1.getRank() != tp2.getRank())
        return op.emitError("unexpected conversion mismatch in rank");
      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
          return op.emitError("unexpected conversion mismatch in dimension ")
                 << d;
      }
      return success();
    }
  }
  return op.emitError("unexpected type in convert");
}

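// Folds a conversion between identical source and destination types into a
// direct use of the source value.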
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  if (getType() == source().getType())
    return source();
  return {};
}

static LogicalResult verify(ToPointersOp op) {
  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
    if (failed(isInBounds(op.dim(), op.tensor())))
      return op.emitError("requested pointers dimension out of bounds");
    if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
      return op.emitError("unexpected type for pointers");
    return success();
  }
  return op.emitError("expected a sparse tensor to get pointers");
}

static LogicalResult verify(ToIndicesOp op) {
  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
    if (failed(isInBounds(op.dim(), op.tensor())))
      return op.emitError("requested indices dimension out of bounds");
    if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
      return op.emitError("unexpected type for indices");
    return success();
  }
  return op.emitError("expected a sparse tensor to get indices");
}

static LogicalResult verify(ToValuesOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to get values");
  RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
  MemRefType mtp = op.result().getType().cast<MemRefType>();
  if (ttp.getElementType() != mtp.getElementType())
    return op.emitError("unexpected mismatch in element types");
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Management Operations.
//===----------------------------------------------------------------------===//

static LogicalResult verify(LexInsertOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for insertion");
  return success();
}

static LogicalResult verify(ExpandOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for expansion");
  return success();
}

static LogicalResult verify(CompressOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for compression");
  return success();
}

static LogicalResult verify(LoadOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to materialize");
  return success();
}

static LogicalResult verify(ReleaseOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor to release");
  return success();
}

static LogicalResult verify(OutOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor for output");
  return success();
}

//===----------------------------------------------------------------------===//
// SparseTensorDialect Methods.
//===----------------------------------------------------------------------===//

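// Registers the ODS-generated attributes and operations with the dialect.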
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"