1319072f4SAart Bik //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2319072f4SAart Bik //
3319072f4SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4319072f4SAart Bik // See https://llvm.org/LICENSE.txt for license information.
5319072f4SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6319072f4SAart Bik //
7319072f4SAart Bik //===----------------------------------------------------------------------===//
8319072f4SAart Bik 
9*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1196a23911SAart Bik #include "mlir/Dialect/StandardOps/IR/Ops.h"
12319072f4SAart Bik #include "mlir/IR/Builders.h"
130a292199SAart Bik #include "mlir/IR/DialectImplementation.h"
14319072f4SAart Bik #include "mlir/IR/OpImplementation.h"
150a292199SAart Bik #include "llvm/ADT/TypeSwitch.h"
16319072f4SAart Bik 
17319072f4SAart Bik using namespace mlir;
18319072f4SAart Bik using namespace mlir::sparse_tensor;
19319072f4SAart Bik 
20485cc55eSStella Laurenzo #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
21485cc55eSStella Laurenzo 
220a292199SAart Bik //===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
240a292199SAart Bik //===----------------------------------------------------------------------===//
250a292199SAart Bik 
260a292199SAart Bik #define GET_ATTRDEF_CLASSES
270a292199SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
280a292199SAart Bik 
/// Returns true when `bitWidth` is one of the storage widths this dialect
/// supports for pointers and indices: 0, 8, 16, 32 or 64. (A width of 0 is
/// paired with the `index` type by isMatchingWidth in this file.)
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
410a292199SAart Bik 
42fb093c83SChris Lattner Attribute SparseTensorEncodingAttr::parse(DialectAsmParser &parser, Type type) {
430a292199SAart Bik   if (failed(parser.parseLess()))
440a292199SAart Bik     return {};
450a292199SAart Bik   // Parse the data as a dictionary.
460a292199SAart Bik   DictionaryAttr dict;
470a292199SAart Bik   if (failed(parser.parseAttribute(dict)))
480a292199SAart Bik     return {};
490a292199SAart Bik   if (failed(parser.parseGreater()))
500a292199SAart Bik     return {};
510a292199SAart Bik   // Process the data from the parsed dictionary value into struct-like data.
520a292199SAart Bik   SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
530a292199SAart Bik   AffineMap map = {};
540a292199SAart Bik   unsigned ptr = 0;
550a292199SAart Bik   unsigned ind = 0;
560a292199SAart Bik   for (const NamedAttribute &attr : dict) {
570a292199SAart Bik     if (attr.first == "dimLevelType") {
580a292199SAart Bik       auto arrayAttr = attr.second.dyn_cast<ArrayAttr>();
590a292199SAart Bik       if (!arrayAttr) {
600a292199SAart Bik         parser.emitError(parser.getNameLoc(),
610a292199SAart Bik                          "expected an array for dimension level types");
620a292199SAart Bik         return {};
630a292199SAart Bik       }
640a292199SAart Bik       for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) {
650a292199SAart Bik         auto strAttr = arrayAttr[i].dyn_cast<StringAttr>();
660a292199SAart Bik         if (!strAttr) {
670a292199SAart Bik           parser.emitError(parser.getNameLoc(),
680a292199SAart Bik                            "expected a string value in dimension level types");
690a292199SAart Bik           return {};
700a292199SAart Bik         }
710a292199SAart Bik         auto strVal = strAttr.getValue();
720a292199SAart Bik         if (strVal == "dense") {
730a292199SAart Bik           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
740a292199SAart Bik         } else if (strVal == "compressed") {
750a292199SAart Bik           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
760a292199SAart Bik         } else if (strVal == "singleton") {
770a292199SAart Bik           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
780a292199SAart Bik         } else {
790a292199SAart Bik           parser.emitError(parser.getNameLoc(),
800a292199SAart Bik                            "unexpected dimension level type: ")
810a292199SAart Bik               << strVal;
820a292199SAart Bik           return {};
830a292199SAart Bik         }
840a292199SAart Bik       }
850a292199SAart Bik     } else if (attr.first == "dimOrdering") {
860a292199SAart Bik       auto affineAttr = attr.second.dyn_cast<AffineMapAttr>();
870a292199SAart Bik       if (!affineAttr) {
880a292199SAart Bik         parser.emitError(parser.getNameLoc(),
890a292199SAart Bik                          "expected an affine map for dimension ordering");
900a292199SAart Bik         return {};
910a292199SAart Bik       }
920a292199SAart Bik       map = affineAttr.getValue();
930a292199SAart Bik     } else if (attr.first == "pointerBitWidth") {
940a292199SAart Bik       auto intAttr = attr.second.dyn_cast<IntegerAttr>();
950a292199SAart Bik       if (!intAttr) {
960a292199SAart Bik         parser.emitError(parser.getNameLoc(),
970a292199SAart Bik                          "expected an integral pointer bitwidth");
980a292199SAart Bik         return {};
990a292199SAart Bik       }
1000a292199SAart Bik       ptr = intAttr.getInt();
1010a292199SAart Bik     } else if (attr.first == "indexBitWidth") {
1020a292199SAart Bik       auto intAttr = attr.second.dyn_cast<IntegerAttr>();
1030a292199SAart Bik       if (!intAttr) {
1040a292199SAart Bik         parser.emitError(parser.getNameLoc(),
1050a292199SAart Bik                          "expected an integral index bitwidth");
1060a292199SAart Bik         return {};
1070a292199SAart Bik       }
1080a292199SAart Bik       ind = intAttr.getInt();
1090a292199SAart Bik     } else {
1100a292199SAart Bik       parser.emitError(parser.getNameLoc(), "unexpected key: ")
1110a292199SAart Bik           << attr.first.str();
1120a292199SAart Bik       return {};
1130a292199SAart Bik     }
1140a292199SAart Bik   }
1150a292199SAart Bik   // Construct struct-like storage for attribute.
116fb093c83SChris Lattner   return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
117fb093c83SChris Lattner                                                      map, ptr, ind);
1180a292199SAart Bik }
1190a292199SAart Bik 
1200a292199SAart Bik void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const {
1210a292199SAart Bik   // Print the struct-like storage in dictionary fashion.
1220a292199SAart Bik   printer << "encoding<{ dimLevelType = [ ";
1230a292199SAart Bik   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
1240a292199SAart Bik     switch (getDimLevelType()[i]) {
1250a292199SAart Bik     case DimLevelType::Dense:
1260a292199SAart Bik       printer << "\"dense\"";
1270a292199SAart Bik       break;
1280a292199SAart Bik     case DimLevelType::Compressed:
1290a292199SAart Bik       printer << "\"compressed\"";
1300a292199SAart Bik       break;
1310a292199SAart Bik     case DimLevelType::Singleton:
1320a292199SAart Bik       printer << "\"singleton\"";
1330a292199SAart Bik       break;
1340a292199SAart Bik     }
1350a292199SAart Bik     if (i != e - 1)
1360a292199SAart Bik       printer << ", ";
1370a292199SAart Bik   }
1380a292199SAart Bik   printer << " ]";
1390a292199SAart Bik   if (getDimOrdering())
1400a292199SAart Bik     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
1410a292199SAart Bik   printer << ", pointerBitWidth = " << getPointerBitWidth()
1420a292199SAart Bik           << ", indexBitWidth = " << getIndexBitWidth() << " }>";
1430a292199SAart Bik }
1440a292199SAart Bik 
1450a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verify(
1460a292199SAart Bik     function_ref<InFlightDiagnostic()> emitError,
1470a292199SAart Bik     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
1480a292199SAart Bik     unsigned pointerBitWidth, unsigned indexBitWidth) {
1490a292199SAart Bik   if (!acceptBitWidth(pointerBitWidth))
1500a292199SAart Bik     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
1510a292199SAart Bik   if (!acceptBitWidth(indexBitWidth))
1520a292199SAart Bik     return emitError() << "unexpected index bitwidth: " << indexBitWidth;
1530a292199SAart Bik   if (dimOrdering) {
1540a292199SAart Bik     if (!dimOrdering.isPermutation())
1550a292199SAart Bik       return emitError()
1560a292199SAart Bik              << "expected a permutation affine map for dimension ordering";
1570a292199SAart Bik     if (dimOrdering.getNumResults() != dimLevelType.size())
1580a292199SAart Bik       return emitError() << "unexpected mismatch in ordering and dimension "
1590a292199SAart Bik                             "level types size";
1600a292199SAart Bik   }
1610a292199SAart Bik   return success();
1620a292199SAart Bik }
1630a292199SAart Bik 
1640a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verifyEncoding(
1650a292199SAart Bik     ArrayRef<int64_t> shape, Type elementType,
1660a292199SAart Bik     function_ref<InFlightDiagnostic()> emitError) const {
1670a292199SAart Bik   // Check structural integrity.
1680a292199SAart Bik   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
1690a292199SAart Bik                     getPointerBitWidth(), getIndexBitWidth())))
1700a292199SAart Bik     return failure();
1710a292199SAart Bik   // Check integrity with tensor type specifics. Dimension ordering is optional,
1720a292199SAart Bik   // but we always should have dimension level types for the full rank.
1730a292199SAart Bik   unsigned size = shape.size();
1740a292199SAart Bik   if (getDimOrdering() && getDimOrdering().getNumResults() != size)
1750a292199SAart Bik     return emitError() << "expected an affine map of size " << size
1760a292199SAart Bik                        << " for dimension ordering";
1770a292199SAart Bik   if (getDimLevelType().size() != size)
1780a292199SAart Bik     return emitError() << "expected an array of size " << size
1790a292199SAart Bik                        << " for dimension level types";
1800a292199SAart Bik   return success();
1810a292199SAart Bik }
1820a292199SAart Bik 
18396a23911SAart Bik SparseTensorEncodingAttr
18496a23911SAart Bik mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
18596a23911SAart Bik   if (auto ttp = type.dyn_cast<RankedTensorType>())
18696a23911SAart Bik     return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
18796a23911SAart Bik   return nullptr;
18896a23911SAart Bik }
18996a23911SAart Bik 
1900a292199SAart Bik //===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
19296a23911SAart Bik //===----------------------------------------------------------------------===//
19396a23911SAart Bik 
19496a23911SAart Bik static LogicalResult isInBounds(Value dim, Value tensor) {
195*a54f4eaeSMogball   if (auto constantOp = dim.getDefiningOp<arith::ConstantOp>()) {
196*a54f4eaeSMogball     unsigned d = constantOp.value().cast<IntegerAttr>().getInt();
19796a23911SAart Bik     if (d >= tensor.getType().cast<RankedTensorType>().getRank())
19896a23911SAart Bik       return failure();
19996a23911SAart Bik   }
20096a23911SAart Bik   return success(); // in bounds, or symbolic
20196a23911SAart Bik }
20296a23911SAart Bik 
20396a23911SAart Bik static LogicalResult isMatchingWidth(Value result, unsigned width) {
20496a23911SAart Bik   Type etp = result.getType().cast<MemRefType>().getElementType();
20596a23911SAart Bik   if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
20696a23911SAart Bik     return success();
20796a23911SAart Bik   return failure();
20896a23911SAart Bik }
20996a23911SAart Bik 
21096a23911SAart Bik static LogicalResult verify(NewOp op) {
211697ea09dSAart Bik   if (!getSparseTensorEncoding(op.result().getType()))
21296a23911SAart Bik     return op.emitError("expected a sparse tensor result");
21396a23911SAart Bik   return success();
21496a23911SAart Bik }
21596a23911SAart Bik 
216697ea09dSAart Bik static LogicalResult verify(ConvertOp op) {
217697ea09dSAart Bik   if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
218697ea09dSAart Bik     if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
219697ea09dSAart Bik       assert(tp1.getRank() == tp2.getRank());
220697ea09dSAart Bik       auto shape1 = tp1.getShape();
221697ea09dSAart Bik       auto shape2 = tp2.getShape();
22205c7f450SAart Bik       for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
223697ea09dSAart Bik         if (shape1[d] != shape2[d])
224697ea09dSAart Bik           return op.emitError()
225697ea09dSAart Bik                  << "unexpected conversion mismatch in dimension " << d;
22605c7f450SAart Bik       }
227697ea09dSAart Bik       return success();
228697ea09dSAart Bik     }
229697ea09dSAart Bik   }
230697ea09dSAart Bik   return op.emitError("unexpected type in convert");
231697ea09dSAart Bik }
232697ea09dSAart Bik 
233066d786cSAart Bik OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
234066d786cSAart Bik   if (getType() == source().getType())
235066d786cSAart Bik     return source();
236066d786cSAart Bik   return {};
237066d786cSAart Bik }
238066d786cSAart Bik 
23916b8f4ddSAart Bik static LogicalResult verify(ReleaseOp op) {
24016b8f4ddSAart Bik   if (!getSparseTensorEncoding(op.tensor().getType()))
24116b8f4ddSAart Bik     return op.emitError("expected a sparse tensor to release");
24216b8f4ddSAart Bik   return success();
24316b8f4ddSAart Bik }
24416b8f4ddSAart Bik 
24596a23911SAart Bik static LogicalResult verify(ToPointersOp op) {
246c2415d67SAart Bik   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
24796a23911SAart Bik     if (failed(isInBounds(op.dim(), op.tensor())))
24896a23911SAart Bik       return op.emitError("requested pointers dimension out of bounds");
24996a23911SAart Bik     if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
25096a23911SAart Bik       return op.emitError("unexpected type for pointers");
25196a23911SAart Bik     return success();
25296a23911SAart Bik   }
25396a23911SAart Bik   return op.emitError("expected a sparse tensor to get pointers");
25496a23911SAart Bik }
25596a23911SAart Bik 
25696a23911SAart Bik static LogicalResult verify(ToIndicesOp op) {
257c2415d67SAart Bik   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
25896a23911SAart Bik     if (failed(isInBounds(op.dim(), op.tensor())))
25996a23911SAart Bik       return op.emitError("requested indices dimension out of bounds");
26096a23911SAart Bik     if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
26196a23911SAart Bik       return op.emitError("unexpected type for indices");
26296a23911SAart Bik     return success();
26396a23911SAart Bik   }
26496a23911SAart Bik   return op.emitError("expected a sparse tensor to get indices");
26596a23911SAart Bik }
26696a23911SAart Bik 
26796a23911SAart Bik static LogicalResult verify(ToValuesOp op) {
26896a23911SAart Bik   if (!getSparseTensorEncoding(op.tensor().getType()))
26996a23911SAart Bik     return op.emitError("expected a sparse tensor to get values");
27096a23911SAart Bik   RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
27196a23911SAart Bik   MemRefType mtp = op.result().getType().cast<MemRefType>();
27296a23911SAart Bik   if (ttp.getElementType() != mtp.getElementType())
27396a23911SAart Bik     return op.emitError("unexpected mismatch in element types");
27496a23911SAart Bik   return success();
27596a23911SAart Bik }
27696a23911SAart Bik 
277727a63e0SAart Bik static LogicalResult verify(ToTensorOp op) {
27836b66ab9SAart Bik   if (!getSparseTensorEncoding(op.result().getType()))
279727a63e0SAart Bik     return op.emitError("expected a sparse tensor as result");
28036b66ab9SAart Bik   return success();
281727a63e0SAart Bik }
282727a63e0SAart Bik 
28396a23911SAart Bik //===----------------------------------------------------------------------===//
// SparseTensorDialect Methods.
2850a292199SAart Bik //===----------------------------------------------------------------------===//
2860a292199SAart Bik 
/// Registers the dialect's attributes and operations. Each template argument
/// list is produced by including a generated .cpp.inc file with the matching
/// GET_ATTRDEF_LIST / GET_OP_LIST macro defined.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
297319072f4SAart Bik 
298319072f4SAart Bik #define GET_OP_CLASSES
299319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
3000a292199SAart Bik 
3010a292199SAart Bik Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser,
3020a292199SAart Bik                                               Type type) const {
3030a292199SAart Bik   StringRef attrTag;
3040a292199SAart Bik   if (failed(parser.parseKeyword(&attrTag)))
3050a292199SAart Bik     return Attribute();
3060a292199SAart Bik   Attribute attr;
307fb093c83SChris Lattner   auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
3080a292199SAart Bik   if (parseResult.hasValue())
3090a292199SAart Bik     return attr;
3100a292199SAart Bik   parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute");
3110a292199SAart Bik   return Attribute();
3120a292199SAart Bik }
3130a292199SAart Bik 
3140a292199SAart Bik void SparseTensorDialect::printAttribute(Attribute attr,
3150a292199SAart Bik                                          DialectAsmPrinter &printer) const {
3160a292199SAart Bik   if (succeeded(generatedAttributePrinter(attr, printer)))
3170a292199SAart Bik     return;
3180a292199SAart Bik }
319