1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "llvm/ADT/TypeSwitch.h"
16 
17 using namespace mlir;
18 using namespace mlir::sparse_tensor;
19 
20 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
21 
22 //===----------------------------------------------------------------------===//
23 // TensorDialect Attribute Methods.
24 //===----------------------------------------------------------------------===//
25 
26 #define GET_ATTRDEF_CLASSES
27 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
28 
/// Returns true when `bitWidth` is one of the widths the encoding verifier
/// accepts for pointers and indices: 0, 8, 16, 32, or 64.
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
41 
42 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
43   if (failed(parser.parseLess()))
44     return {};
45   // Parse the data as a dictionary.
46   DictionaryAttr dict;
47   if (failed(parser.parseAttribute(dict)))
48     return {};
49   if (failed(parser.parseGreater()))
50     return {};
51   // Process the data from the parsed dictionary value into struct-like data.
52   SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
53   AffineMap map = {};
54   unsigned ptr = 0;
55   unsigned ind = 0;
56   for (const NamedAttribute &attr : dict) {
57     if (attr.getName() == "dimLevelType") {
58       auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
59       if (!arrayAttr) {
60         parser.emitError(parser.getNameLoc(),
61                          "expected an array for dimension level types");
62         return {};
63       }
64       for (unsigned i = 0, e = arrayAttr.size(); i < e; i++) {
65         auto strAttr = arrayAttr[i].dyn_cast<StringAttr>();
66         if (!strAttr) {
67           parser.emitError(parser.getNameLoc(),
68                            "expected a string value in dimension level types");
69           return {};
70         }
71         auto strVal = strAttr.getValue();
72         if (strVal == "dense") {
73           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
74         } else if (strVal == "compressed") {
75           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
76         } else if (strVal == "singleton") {
77           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
78         } else {
79           parser.emitError(parser.getNameLoc(),
80                            "unexpected dimension level type: ")
81               << strVal;
82           return {};
83         }
84       }
85     } else if (attr.getName() == "dimOrdering") {
86       auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
87       if (!affineAttr) {
88         parser.emitError(parser.getNameLoc(),
89                          "expected an affine map for dimension ordering");
90         return {};
91       }
92       map = affineAttr.getValue();
93     } else if (attr.getName() == "pointerBitWidth") {
94       auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
95       if (!intAttr) {
96         parser.emitError(parser.getNameLoc(),
97                          "expected an integral pointer bitwidth");
98         return {};
99       }
100       ptr = intAttr.getInt();
101     } else if (attr.getName() == "indexBitWidth") {
102       auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
103       if (!intAttr) {
104         parser.emitError(parser.getNameLoc(),
105                          "expected an integral index bitwidth");
106         return {};
107       }
108       ind = intAttr.getInt();
109     } else {
110       parser.emitError(parser.getNameLoc(), "unexpected key: ")
111           << attr.getName().strref();
112       return {};
113     }
114   }
115   // Construct struct-like storage for attribute.
116   return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
117                                                      map, ptr, ind);
118 }
119 
120 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
121   // Print the struct-like storage in dictionary fashion.
122   printer << "<{ dimLevelType = [ ";
123   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
124     switch (getDimLevelType()[i]) {
125     case DimLevelType::Dense:
126       printer << "\"dense\"";
127       break;
128     case DimLevelType::Compressed:
129       printer << "\"compressed\"";
130       break;
131     case DimLevelType::Singleton:
132       printer << "\"singleton\"";
133       break;
134     }
135     if (i != e - 1)
136       printer << ", ";
137   }
138   printer << " ]";
139   if (getDimOrdering())
140     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
141   printer << ", pointerBitWidth = " << getPointerBitWidth()
142           << ", indexBitWidth = " << getIndexBitWidth() << " }>";
143 }
144 
145 LogicalResult SparseTensorEncodingAttr::verify(
146     function_ref<InFlightDiagnostic()> emitError,
147     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
148     unsigned pointerBitWidth, unsigned indexBitWidth) {
149   if (!acceptBitWidth(pointerBitWidth))
150     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
151   if (!acceptBitWidth(indexBitWidth))
152     return emitError() << "unexpected index bitwidth: " << indexBitWidth;
153   if (dimOrdering) {
154     if (!dimOrdering.isPermutation())
155       return emitError()
156              << "expected a permutation affine map for dimension ordering";
157     if (dimOrdering.getNumResults() != dimLevelType.size())
158       return emitError() << "unexpected mismatch in ordering and dimension "
159                             "level types size";
160   }
161   return success();
162 }
163 
164 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
165     ArrayRef<int64_t> shape, Type elementType,
166     function_ref<InFlightDiagnostic()> emitError) const {
167   // Check structural integrity.
168   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
169                     getPointerBitWidth(), getIndexBitWidth())))
170     return failure();
171   // Check integrity with tensor type specifics. Dimension ordering is optional,
172   // but we always should have dimension level types for the full rank.
173   unsigned size = shape.size();
174   if (size == 0)
175     return emitError() << "expected non-scalar sparse tensor";
176   if (getDimOrdering() && getDimOrdering().getNumResults() != size)
177     return emitError() << "expected an affine map of size " << size
178                        << " for dimension ordering";
179   if (getDimLevelType().size() != size)
180     return emitError() << "expected an array of size " << size
181                        << " for dimension level types";
182   return success();
183 }
184 
185 SparseTensorEncodingAttr
186 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
187   if (auto ttp = type.dyn_cast<RankedTensorType>())
188     return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
189   return nullptr;
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // TensorDialect Operations.
194 //===----------------------------------------------------------------------===//
195 
196 static LogicalResult isInBounds(Value dim, Value tensor) {
197   if (auto constantOp = dim.getDefiningOp<arith::ConstantOp>()) {
198     unsigned d = constantOp.getValue().cast<IntegerAttr>().getInt();
199     if (d >= tensor.getType().cast<RankedTensorType>().getRank())
200       return failure();
201   }
202   return success(); // in bounds, or symbolic
203 }
204 
205 static LogicalResult isMatchingWidth(Value result, unsigned width) {
206   Type etp = result.getType().cast<MemRefType>().getElementType();
207   if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
208     return success();
209   return failure();
210 }
211 
212 static LogicalResult verify(NewOp op) {
213   if (!getSparseTensorEncoding(op.result().getType()))
214     return op.emitError("expected a sparse tensor result");
215   return success();
216 }
217 
218 static LogicalResult verify(InitOp op) {
219   if (!getSparseTensorEncoding(op.result().getType()))
220     return op.emitError("expected a sparse tensor result");
221   RankedTensorType ttp = op.getType().cast<RankedTensorType>();
222   unsigned rank = ttp.getRank();
223   if (rank != op.sizes().size())
224     return op.emitError("unexpected mismatch between tensor rank and sizes: ")
225            << rank << " vs. " << op.sizes().size();
226   auto shape = ttp.getShape();
227   for (unsigned i = 0; i < rank; i++) {
228     if (shape[i] == ShapedType::kDynamicSize)
229       continue;
230     auto constantOp = op.sizes()[i].getDefiningOp<arith::ConstantOp>();
231     if (!constantOp ||
232         constantOp.getValue().cast<IntegerAttr>().getInt() != shape[i])
233       return op.emitError("unexpected mismatch with static dimension size ")
234              << shape[i];
235   }
236   return success();
237 }
238 
239 static LogicalResult verify(ConvertOp op) {
240   if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
241     if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
242       if (tp1.getRank() != tp2.getRank())
243         return op.emitError("unexpected conversion mismatch in rank");
244       auto shape1 = tp1.getShape();
245       auto shape2 = tp2.getShape();
246       // Accept size matches between the source and the destination type
247       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
248       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
249       for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
250         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
251           return op.emitError("unexpected conversion mismatch in dimension ")
252                  << d;
253       }
254       return success();
255     }
256   }
257   return op.emitError("unexpected type in convert");
258 }
259 
260 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
261   if (getType() == source().getType())
262     return source();
263   return {};
264 }
265 
266 static LogicalResult verify(ToPointersOp op) {
267   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
268     if (failed(isInBounds(op.dim(), op.tensor())))
269       return op.emitError("requested pointers dimension out of bounds");
270     if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
271       return op.emitError("unexpected type for pointers");
272     return success();
273   }
274   return op.emitError("expected a sparse tensor to get pointers");
275 }
276 
277 static LogicalResult verify(ToIndicesOp op) {
278   if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
279     if (failed(isInBounds(op.dim(), op.tensor())))
280       return op.emitError("requested indices dimension out of bounds");
281     if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
282       return op.emitError("unexpected type for indices");
283     return success();
284   }
285   return op.emitError("expected a sparse tensor to get indices");
286 }
287 
288 static LogicalResult verify(ToValuesOp op) {
289   if (!getSparseTensorEncoding(op.tensor().getType()))
290     return op.emitError("expected a sparse tensor to get values");
291   RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
292   MemRefType mtp = op.result().getType().cast<MemRefType>();
293   if (ttp.getElementType() != mtp.getElementType())
294     return op.emitError("unexpected mismatch in element types");
295   return success();
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // TensorDialect Management Operations.
300 //===----------------------------------------------------------------------===//
301 
302 static LogicalResult verify(LexInsertOp op) {
303   if (!getSparseTensorEncoding(op.tensor().getType()))
304     return op.emitError("expected a sparse tensor for insertion");
305   return success();
306 }
307 
308 static LogicalResult verify(ExpandOp op) {
309   if (!getSparseTensorEncoding(op.tensor().getType()))
310     return op.emitError("expected a sparse tensor for expansion");
311   return success();
312 }
313 
314 static LogicalResult verify(CompressOp op) {
315   if (!getSparseTensorEncoding(op.tensor().getType()))
316     return op.emitError("expected a sparse tensor for compression");
317   return success();
318 }
319 
320 static LogicalResult verify(LoadOp op) {
321   if (!getSparseTensorEncoding(op.tensor().getType()))
322     return op.emitError("expected a sparse tensor to materialize");
323   return success();
324 }
325 
326 static LogicalResult verify(ReleaseOp op) {
327   if (!getSparseTensorEncoding(op.tensor().getType()))
328     return op.emitError("expected a sparse tensor to release");
329   return success();
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // TensorDialect Methods.
334 //===----------------------------------------------------------------------===//
335 
/// Registers the attributes and operations of the sparse tensor dialect.
///
/// Both template argument lists are not spelled out here: each is produced by
/// including a generated `.cpp.inc` file with the matching `GET_*_LIST` macro
/// defined right before the include.
void SparseTensorDialect::initialize() {
  // Attribute list from SparseTensorAttrDefs.cpp.inc (GET_ATTRDEF_LIST).
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  // Operation list from SparseTensorOps.cpp.inc (GET_OP_LIST).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}
346 
347 #define GET_OP_CLASSES
348 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
349 
350 Attribute SparseTensorDialect::parseAttribute(DialectAsmParser &parser,
351                                               Type type) const {
352   StringRef attrTag;
353   if (failed(parser.parseKeyword(&attrTag)))
354     return Attribute();
355   Attribute attr;
356   auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
357   if (parseResult.hasValue())
358     return attr;
359   parser.emitError(parser.getNameLoc(), "unknown sparse tensor attribute");
360   return Attribute();
361 }
362 
363 void SparseTensorDialect::printAttribute(Attribute attr,
364                                          DialectAsmPrinter &printer) const {
365   if (succeeded(generatedAttributePrinter(attr, printer)))
366     return;
367 }
368