1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
10 #include "mlir/Dialect/StandardOps/IR/Ops.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::sparse_tensor;
18 
19 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
20 
21 //===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
23 //===----------------------------------------------------------------------===//
24 
25 #define GET_ATTRDEF_CLASSES
26 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
27 
/// Returns true if `bitWidth` is a permitted pointer/index bitwidth for the
/// sparse tensor encoding: 0 (which isMatchingWidth pairs with the `index`
/// element type) or one of 8, 16, 32, 64.
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
40 
41 Attribute SparseTensorEncodingAttr::parse(DialectAsmParser &parser, Type type) {
42   if (failed(parser.parseLess()))
43     return {};
44   // Parse the data as a dictionary.
45   DictionaryAttr dict;
46   if (failed(parser.parseAttribute(dict)))
47     return {};
48   if (failed(parser.parseGreater()))
49     return {};
50   // Process the data from the parsed dictionary value into struct-like data.
51   SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
52   AffineMap map = {};
53   unsigned ptr = 0;
54   unsigned ind = 0;
55   for (const NamedAttribute &attr : dict) {
56     if (attr.first == "dimLevelType") {
57       auto arrayAttr = attr.second.dyn_cast<ArrayAttr>();
58       if (!arrayAttr) {
59         parser.emitError(parser.getNameLoc(),
60                          "expected an array for dimension level types");
61         return {};
62       }
63       for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) {
64         auto strAttr = arrayAttr[i].dyn_cast<StringAttr>();
65         if (!strAttr) {
66           parser.emitError(parser.getNameLoc(),
67                            "expected a string value in dimension level types");
68           return {};
69         }
70         auto strVal = strAttr.getValue();
71         if (strVal == "dense") {
72           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
73         } else if (strVal == "compressed") {
74           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
75         } else if (strVal == "singleton") {
76           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
77         } else {
78           parser.emitError(parser.getNameLoc(),
79                            "unexpected dimension level type: ")
80               << strVal;
81           return {};
82         }
83       }
84     } else if (attr.first == "dimOrdering") {
85       auto affineAttr = attr.second.dyn_cast<AffineMapAttr>();
86       if (!affineAttr) {
87         parser.emitError(parser.getNameLoc(),
88                          "expected an affine map for dimension ordering");
89         return {};
90       }
91       map = affineAttr.getValue();
92     } else if (attr.first == "pointerBitWidth") {
93       auto intAttr = attr.second.dyn_cast<IntegerAttr>();
94       if (!intAttr) {
95         parser.emitError(parser.getNameLoc(),
96                          "expected an integral pointer bitwidth");
97         return {};
98       }
99       ptr = intAttr.getInt();
100     } else if (attr.first == "indexBitWidth") {
101       auto intAttr = attr.second.dyn_cast<IntegerAttr>();
102       if (!intAttr) {
103         parser.emitError(parser.getNameLoc(),
104                          "expected an integral index bitwidth");
105         return {};
106       }
107       ind = intAttr.getInt();
108     } else {
109       parser.emitError(parser.getNameLoc(), "unexpected key: ")
110           << attr.first.str();
111       return {};
112     }
113   }
114   // Construct struct-like storage for attribute.
115   return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
116                                                      map, ptr, ind);
117 }
118 
119 void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const {
120   // Print the struct-like storage in dictionary fashion.
121   printer << "encoding<{ dimLevelType = [ ";
122   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
123     switch (getDimLevelType()[i]) {
124     case DimLevelType::Dense:
125       printer << "\"dense\"";
126       break;
127     case DimLevelType::Compressed:
128       printer << "\"compressed\"";
129       break;
130     case DimLevelType::Singleton:
131       printer << "\"singleton\"";
132       break;
133     }
134     if (i != e - 1)
135       printer << ", ";
136   }
137   printer << " ]";
138   if (getDimOrdering())
139     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
140   printer << ", pointerBitWidth = " << getPointerBitWidth()
141           << ", indexBitWidth = " << getIndexBitWidth() << " }>";
142 }
143 
144 LogicalResult SparseTensorEncodingAttr::verify(
145     function_ref<InFlightDiagnostic()> emitError,
146     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
147     unsigned pointerBitWidth, unsigned indexBitWidth) {
148   if (!acceptBitWidth(pointerBitWidth))
149     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
150   if (!acceptBitWidth(indexBitWidth))
151     return emitError() << "unexpected index bitwidth: " << indexBitWidth;
152   if (dimOrdering) {
153     if (!dimOrdering.isPermutation())
154       return emitError()
155              << "expected a permutation affine map for dimension ordering";
156     if (dimOrdering.getNumResults() != dimLevelType.size())
157       return emitError() << "unexpected mismatch in ordering and dimension "
158                             "level types size";
159   }
160   return success();
161 }
162 
163 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
164     ArrayRef<int64_t> shape, Type elementType,
165     function_ref<InFlightDiagnostic()> emitError) const {
166   // Check structural integrity.
167   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
168                     getPointerBitWidth(), getIndexBitWidth())))
169     return failure();
170   // Check integrity with tensor type specifics. Dimension ordering is optional,
171   // but we always should have dimension level types for the full rank.
172   unsigned size = shape.size();
173   if (getDimOrdering() && getDimOrdering().getNumResults() != size)
174     return emitError() << "expected an affine map of size " << size
175                        << " for dimension ordering";
176   if (getDimLevelType().size() != size)
177     return emitError() << "expected an array of size " << size
178                        << " for dimension level types";
179   return success();
180 }
181 
182 SparseTensorEncodingAttr
183 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
184   if (auto ttp = type.dyn_cast<RankedTensorType>())
185     return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
186   return nullptr;
187 }
188 
189 //===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
191 //===----------------------------------------------------------------------===//
192 
193 static LogicalResult isInBounds(Value dim, Value tensor) {
194   if (auto constantOp = dim.getDefiningOp<ConstantOp>()) {
195     unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt();
196     if (d >= tensor.getType().cast<RankedTensorType>().getRank())
197       return failure();
198   }
199   return success(); // in bounds, or symbolic
200 }
201 
202 static LogicalResult isMatchingWidth(Value result, unsigned width) {
203   Type etp = result.getType().cast<MemRefType>().getElementType();
204   if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
205     return success();
206   return failure();
207 }
208 
209 static LogicalResult verify(NewOp op) {
210   if (!getSparseTensorEncoding(op.result().getType()))
211     return op.emitError("expected a sparse tensor result");
212   return success();
213 }
214 
215 static LogicalResult verify(ConvertOp op) {
216   if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
217     if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
218       assert(tp1.getRank() == tp2.getRank());
219       auto shape1 = tp1.getShape();
220       auto shape2 = tp2.getShape();
221       for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
222         if (shape1[d] != shape2[d])
223           return op.emitError()
224                  << "unexpected conversion mismatch in dimension " << d;
225       }
226       return success();
227     }
228   }
229   return op.emitError("unexpected type in convert");
230 }
231 
232 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
233   if (getType() == source().getType())
234     return source();
235   return {};
236 }
237 
238 static LogicalResult verify(ToPointersOp op) {
239   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
240     if (failed(isInBounds(op.dim(), op.tensor())))
241       return op.emitError("requested pointers dimension out of bounds");
242     if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
243       return op.emitError("unexpected type for pointers");
244     return success();
245   }
246   return op.emitError("expected a sparse tensor to get pointers");
247 }
248 
249 static LogicalResult verify(ToIndicesOp op) {
250   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
251     if (failed(isInBounds(op.dim(), op.tensor())))
252       return op.emitError("requested indices dimension out of bounds");
253     if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
254       return op.emitError("unexpected type for indices");
255     return success();
256   }
257   return op.emitError("expected a sparse tensor to get indices");
258 }
259 
260 static LogicalResult verify(ToValuesOp op) {
261   if (!getSparseTensorEncoding(op.tensor().getType()))
262     return op.emitError("expected a sparse tensor to get values");
263   RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
264   MemRefType mtp = op.result().getType().cast<MemRefType>();
265   if (ttp.getElementType() != mtp.getElementType())
266     return op.emitError("unexpected mismatch in element types");
267   return success();
268 }
269 
270 static LogicalResult verify(ToTensorOp op) {
271   if (!getSparseTensorEncoding(op.result().getType()))
272     return op.emitError("expected a sparse tensor as result");
273   return success();
274 }
275 
276 //===----------------------------------------------------------------------===//
// SparseTensorDialect Methods.
278 //===----------------------------------------------------------------------===//
279 
/// Registers the dialect's attribute and operation classes. The class lists
/// are expanded from the generated .inc files selected by the GET_ATTRDEF_LIST
/// and GET_OP_LIST macros, so the `#define`/`#include` lines must stay inside
/// the template argument lists exactly as written.
void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
290 
291 #define GET_OP_CLASSES
292 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
293 
294 Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser,
295                                               Type type) const {
296   StringRef attrTag;
297   if (failed(parser.parseKeyword(&attrTag)))
298     return Attribute();
299   Attribute attr;
300   auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
301   if (parseResult.hasValue())
302     return attr;
303   parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute");
304   return Attribute();
305 }
306 
307 void SparseTensorDialect::printAttribute(Attribute attr,
308                                          DialectAsmPrinter &printer) const {
309   if (succeeded(generatedAttributePrinter(attr, printer)))
310     return;
311 }
312