1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 //===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
21 //===----------------------------------------------------------------------===//
22 
23 #define GET_ATTRDEF_CLASSES
24 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
25 
/// Returns true when `bitWidth` is one of the widths accepted for pointer and
/// index storage: 0, 8, 16, 32, or 64. Every other value is rejected.
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
38 
39 Attribute SparseTensorEncodingAttr::parse(MLIRContext *context,
40                                           DialectAsmParser &parser, Type type) {
41   if (failed(parser.parseLess()))
42     return {};
43   // Parse the data as a dictionary.
44   DictionaryAttr dict;
45   if (failed(parser.parseAttribute(dict)))
46     return {};
47   if (failed(parser.parseGreater()))
48     return {};
49   // Process the data from the parsed dictionary value into struct-like data.
50   SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
51   AffineMap map = {};
52   unsigned ptr = 0;
53   unsigned ind = 0;
54   for (const NamedAttribute &attr : dict) {
55     if (attr.first == "dimLevelType") {
56       auto arrayAttr = attr.second.dyn_cast<ArrayAttr>();
57       if (!arrayAttr) {
58         parser.emitError(parser.getNameLoc(),
59                          "expected an array for dimension level types");
60         return {};
61       }
62       for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) {
63         auto strAttr = arrayAttr[i].dyn_cast<StringAttr>();
64         if (!strAttr) {
65           parser.emitError(parser.getNameLoc(),
66                            "expected a string value in dimension level types");
67           return {};
68         }
69         auto strVal = strAttr.getValue();
70         if (strVal == "dense") {
71           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
72         } else if (strVal == "compressed") {
73           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
74         } else if (strVal == "singleton") {
75           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
76         } else {
77           parser.emitError(parser.getNameLoc(),
78                            "unexpected dimension level type: ")
79               << strVal;
80           return {};
81         }
82       }
83     } else if (attr.first == "dimOrdering") {
84       auto affineAttr = attr.second.dyn_cast<AffineMapAttr>();
85       if (!affineAttr) {
86         parser.emitError(parser.getNameLoc(),
87                          "expected an affine map for dimension ordering");
88         return {};
89       }
90       map = affineAttr.getValue();
91     } else if (attr.first == "pointerBitWidth") {
92       auto intAttr = attr.second.dyn_cast<IntegerAttr>();
93       if (!intAttr) {
94         parser.emitError(parser.getNameLoc(),
95                          "expected an integral pointer bitwidth");
96         return {};
97       }
98       ptr = intAttr.getInt();
99     } else if (attr.first == "indexBitWidth") {
100       auto intAttr = attr.second.dyn_cast<IntegerAttr>();
101       if (!intAttr) {
102         parser.emitError(parser.getNameLoc(),
103                          "expected an integral index bitwidth");
104         return {};
105       }
106       ind = intAttr.getInt();
107     } else {
108       parser.emitError(parser.getNameLoc(), "unexpected key: ")
109           << attr.first.str();
110       return {};
111     }
112   }
113   // Construct struct-like storage for attribute.
114   return parser.getChecked<SparseTensorEncodingAttr>(context, dlt, map, ptr,
115                                                      ind);
116 }
117 
118 void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const {
119   // Print the struct-like storage in dictionary fashion.
120   printer << "encoding<{ dimLevelType = [ ";
121   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
122     switch (getDimLevelType()[i]) {
123     case DimLevelType::Dense:
124       printer << "\"dense\"";
125       break;
126     case DimLevelType::Compressed:
127       printer << "\"compressed\"";
128       break;
129     case DimLevelType::Singleton:
130       printer << "\"singleton\"";
131       break;
132     }
133     if (i != e - 1)
134       printer << ", ";
135   }
136   printer << " ]";
137   if (getDimOrdering())
138     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
139   printer << ", pointerBitWidth = " << getPointerBitWidth()
140           << ", indexBitWidth = " << getIndexBitWidth() << " }>";
141 }
142 
143 LogicalResult SparseTensorEncodingAttr::verify(
144     function_ref<InFlightDiagnostic()> emitError,
145     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
146     unsigned pointerBitWidth, unsigned indexBitWidth) {
147   if (!acceptBitWidth(pointerBitWidth))
148     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
149   if (!acceptBitWidth(indexBitWidth))
150     return emitError() << "unexpected index bitwidth: " << indexBitWidth;
151   if (dimOrdering) {
152     if (!dimOrdering.isPermutation())
153       return emitError()
154              << "expected a permutation affine map for dimension ordering";
155     if (dimOrdering.getNumResults() != dimLevelType.size())
156       return emitError() << "unexpected mismatch in ordering and dimension "
157                             "level types size";
158   }
159   return success();
160 }
161 
162 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
163     ArrayRef<int64_t> shape, Type elementType,
164     function_ref<InFlightDiagnostic()> emitError) const {
165   // Check structural integrity.
166   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
167                     getPointerBitWidth(), getIndexBitWidth())))
168     return failure();
169   // Check integrity with tensor type specifics. Dimension ordering is optional,
170   // but we always should have dimension level types for the full rank.
171   unsigned size = shape.size();
172   if (getDimOrdering() && getDimOrdering().getNumResults() != size)
173     return emitError() << "expected an affine map of size " << size
174                        << " for dimension ordering";
175   if (getDimLevelType().size() != size)
176     return emitError() << "expected an array of size " << size
177                        << " for dimension level types";
178   return success();
179 }
180 
181 SparseTensorEncodingAttr
182 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
183   if (auto ttp = type.dyn_cast<RankedTensorType>())
184     return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
185   return nullptr;
186 }
187 
188 //===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
190 //===----------------------------------------------------------------------===//
191 
192 static LogicalResult isInBounds(Value dim, Value tensor) {
193   if (auto constantOp = dim.getDefiningOp<ConstantOp>()) {
194     unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt();
195     if (d >= tensor.getType().cast<RankedTensorType>().getRank())
196       return failure();
197   }
198   return success(); // in bounds, or symbolic
199 }
200 
201 static LogicalResult isMatchingWidth(Value result, unsigned width) {
202   Type etp = result.getType().cast<MemRefType>().getElementType();
203   if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
204     return success();
205   return failure();
206 }
207 
208 static LogicalResult verify(NewOp op) {
209   if (!getSparseTensorEncoding(op.getResult().getType()))
210     return op.emitError("expected a sparse tensor result");
211   return success();
212 }
213 
214 static LogicalResult verify(ToPointersOp op) {
215   if (failed(isInBounds(op.dim(), op.tensor())))
216     return op.emitError("requested pointers dimension out of bounds");
217   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
218     if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
219       return op.emitError("unexpected type for pointers");
220     return success();
221   }
222   return op.emitError("expected a sparse tensor to get pointers");
223 }
224 
225 static LogicalResult verify(ToIndicesOp op) {
226   if (failed(isInBounds(op.dim(), op.tensor())))
227     return op.emitError("requested indices dimension out of bounds");
228   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
229     if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
230       return op.emitError("unexpected type for indices");
231     return success();
232   }
233   return op.emitError("expected a sparse tensor to get indices");
234 }
235 
236 static LogicalResult verify(ToValuesOp op) {
237   if (!getSparseTensorEncoding(op.tensor().getType()))
238     return op.emitError("expected a sparse tensor to get values");
239   RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
240   MemRefType mtp = op.result().getType().cast<MemRefType>();
241   if (ttp.getElementType() != mtp.getElementType())
242     return op.emitError("unexpected mismatch in element types");
243   return success();
244 }
245 
246 //===----------------------------------------------------------------------===//
// SparseTensorDialect Methods.
248 //===----------------------------------------------------------------------===//
249 
/// Registers the attributes and operations of the sparse tensor dialect. The
/// template argument lists are supplied by the included .inc files, with
/// GET_ATTRDEF_LIST and GET_OP_LIST defined right before each include.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
260 
261 #define GET_OP_CLASSES
262 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
263 
264 Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser,
265                                               Type type) const {
266   StringRef attrTag;
267   if (failed(parser.parseKeyword(&attrTag)))
268     return Attribute();
269   Attribute attr;
270   auto parseResult =
271       generatedAttributeParser(getContext(), parser, attrTag, type, attr);
272   if (parseResult.hasValue())
273     return attr;
274   parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute");
275   return Attribute();
276 }
277 
278 void SparseTensorDialect::printAttribute(Attribute attr,
279                                          DialectAsmPrinter &printer) const {
280   if (succeeded(generatedAttributePrinter(attr, printer)))
281     return;
282 }
283