//===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// TensorDialect Attribute Methods.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"

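/// Returns true if the given bitwidth is one of the supported widths,
/// where a zero width selects the native `index` type (cf. isMatchingWidth).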
static bool acceptBitWidth(unsigned bitWidth) {
  switch (bitWidth) {
  case 0:
  case 8:
  case 16:
  case 32:
  case 64:
    return true;
  default:
    return false;
  }
}

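/// Parses the custom attribute syntax, mirroring what print() below emits,
/// e.g.:
///   #sparse_tensor.encoding<{
///     dimLevelType = [ "dense", "compressed" ],
///     dimOrdering = affine_map<(i, j) -> (j, i)>,
///     pointerBitWidth = 64,
///     indexBitWidth = 64
///   }>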
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()))
    return {};
  // Parse the data as a dictionary.
  DictionaryAttr dict;
  if (failed(parser.parseAttribute(dict)))
    return {};
  if (failed(parser.parseGreater()))
    return {};
  // Process the data from the parsed dictionary value into struct-like data.
  SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
  AffineMap map = {};
  unsigned ptr = 0;
  unsigned ind = 0;
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "dimLevelType") {
      auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
      if (!arrayAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an array for dimension level types");
        return {};
      }
      for (auto i : arrayAttr) {
        auto strAttr = i.dyn_cast<StringAttr>();
        if (!strAttr) {
          parser.emitError(parser.getNameLoc(),
                           "expected a string value in dimension level types");
          return {};
        }
        auto strVal = strAttr.getValue();
        if (strVal == "dense") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
        } else if (strVal == "compressed") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
        } else if (strVal == "singleton") {
          dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
        } else {
          parser.emitError(parser.getNameLoc(),
                           "unexpected dimension level type: ")
              << strVal;
          return {};
        }
      }
    } else if (attr.getName() == "dimOrdering") {
      auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
      if (!affineAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an affine map for dimension ordering");
        return {};
      }
      map = affineAttr.getValue();
    } else if (attr.getName() == "pointerBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral pointer bitwidth");
        return {};
      }
      ptr = intAttr.getInt();
    } else if (attr.getName() == "indexBitWidth") {
      auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
      if (!intAttr) {
        parser.emitError(parser.getNameLoc(),
                         "expected an integral index bitwidth");
        return {};
      }
      ind = intAttr.getInt();
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }
  // Construct struct-like storage for attribute.
  return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
                                                     map, ptr, ind);
}

void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
  // Print the struct-like storage in dictionary fashion.
  printer << "<{ dimLevelType = [ ";
  for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
    switch (getDimLevelType()[i]) {
    case DimLevelType::Dense:
      printer << "\"dense\"";
      break;
    case DimLevelType::Compressed:
      printer << "\"compressed\"";
      break;
    case DimLevelType::Singleton:
      printer << "\"singleton\"";
      break;
    }
    if (i != e - 1)
      printer << ", ";
  }
  printer << " ]";
  if (getDimOrdering())
    printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
  printer << ", pointerBitWidth = " << getPointerBitWidth()
          << ", indexBitWidth = " << getIndexBitWidth() << " }>";
}

LogicalResult SparseTensorEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError,
    ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
    unsigned pointerBitWidth, unsigned indexBitWidth) {
  if (!acceptBitWidth(pointerBitWidth))
    return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
  if (!acceptBitWidth(indexBitWidth))
    return emitError() << "unexpected index bitwidth: " << indexBitWidth;
  if (dimOrdering) {
    if (!dimOrdering.isPermutation())
      return emitError()
             << "expected a permutation affine map for dimension ordering";
    if (dimOrdering.getNumResults() != dimLevelType.size())
      return emitError() << "unexpected mismatch in ordering and dimension "
                            "level types size";
  }
  return success();
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
    ArrayRef<int64_t> shape, Type elementType,
    function_ref<InFlightDiagnostic()> emitError) const {
  // Check structural integrity.
  if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
                    getPointerBitWidth(), getIndexBitWidth())))
    return failure();
  // Check integrity with tensor type specifics. Dimension ordering is
  // optional, but we should always have dimension level types for the
  // full rank.
  unsigned size = shape.size();
  if (size == 0)
    return emitError() << "expected non-scalar sparse tensor";
  if (getDimOrdering() && getDimOrdering().getNumResults() != size)
    return emitError() << "expected an affine map of size " << size
                       << " for dimension ordering";
  if (getDimLevelType().size() != size)
    return emitError() << "expected an array of size " << size
                       << " for dimension level types";
  return success();
}

SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  if (auto ttp = type.dyn_cast<RankedTensorType>())
    return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// TensorDialect Operations.
//===----------------------------------------------------------------------===//

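/// Returns success if a constant `dim` lies within the rank of `tensor`;
/// non-constant (symbolic) dimensions are conservatively accepted.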
static LogicalResult isInBounds(Value dim, Value tensor) {
  IntegerAttr constantAttr;
  if (matchPattern(dim, m_Constant(&constantAttr))) {
    unsigned d = constantAttr.getInt();
    if (d >= tensor.getType().cast<RankedTensorType>().getRank())
      return failure();
  }
  return success(); // in bounds, or symbolic
}

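/// Returns success if the memref element type of `result` matches the
/// given bitwidth, where a zero width denotes the native `index` type.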
static LogicalResult isMatchingWidth(Value result, unsigned width) {
  Type etp = result.getType().cast<MemRefType>().getElementType();
  if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
    return success();
  return failure();
}

LogicalResult ConvertOp::verify() {
  if (auto tp1 = getSource().getType().dyn_cast<RankedTensorType>()) {
    if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) {
      if (tp1.getRank() != tp2.getRank())
        return emitError("unexpected conversion mismatch in rank");
      auto shape1 = tp1.getShape();
      auto shape2 = tp2.getShape();
      // Accept size matches between the source and the destination type
      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
          return emitError("unexpected conversion mismatch in dimension ") << d;
      return success();
    }
  }
  return emitError("unexpected type in convert");
}

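/// Folds the trivial case where source and destination types coincide.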
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
  if (getType() == getSource().getType())
    return getSource();
  return {};
}

LogicalResult ToPointersOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (failed(isInBounds(getDim(), getTensor())))
    return emitError("requested pointers dimension out of bounds");
  if (failed(isMatchingWidth(getResult(), e.getPointerBitWidth())))
    return emitError("unexpected type for pointers");
  return success();
}

LogicalResult ToIndicesOp::verify() {
  auto e = getSparseTensorEncoding(getTensor().getType());
  if (failed(isInBounds(getDim(), getTensor())))
    return emitError("requested indices dimension out of bounds");
  if (failed(isMatchingWidth(getResult(), e.getIndexBitWidth())))
    return emitError("unexpected type for indices");
  return success();
}

LogicalResult ToValuesOp::verify() {
  RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
  MemRefType mtp = getResult().getType().cast<MemRefType>();
  if (ttp.getElementType() != mtp.getElementType())
    return emitError("unexpected mismatch in element types");
  return success();
}

//===----------------------------------------------------------------------===//
// TensorDialect Linalg.Generic Operations.
//===----------------------------------------------------------------------===//

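/// Verifies that `region` takes block arguments matching `inputTypes` and
/// terminates with a sparse_tensor.yield whose operand has `outputType`.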
template <class T>
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
                                        const char *regionName,
                                        TypeRange inputTypes, Type outputType) {
  unsigned numArgs = region.getNumArguments();
  unsigned expectedNum = inputTypes.size();
  if (numArgs != expectedNum)
    return op->emitError() << regionName << " region must have exactly "
                           << expectedNum << " arguments";

  for (unsigned i = 0; i < numArgs; i++) {
    Type typ = region.getArgument(i).getType();
    if (typ != inputTypes[i])
      return op->emitError() << regionName << " region argument " << (i + 1)
                             << " type mismatch";
  }
  Operation *term = region.front().getTerminator();
  YieldOp yield = dyn_cast<YieldOp>(term);
  if (!yield)
    return op->emitError() << regionName
                           << " region must end with sparse_tensor.yield";
  if (yield.getOperand().getType() != outputType)
    return op->emitError() << regionName << " region yield type mismatch";

  return success();
}

LogicalResult BinaryOp::verify() {
  NamedAttrList attrs = (*this)->getAttrs();
  Type leftType = getX().getType();
  Type rightType = getY().getType();
  Type outputType = getOutput().getType();
  Region &overlap = getOverlapRegion();
  Region &left = getLeftRegion();
  Region &right = getRightRegion();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  LogicalResult regionResult = success();
  if (!overlap.empty()) {
    regionResult = verifyNumBlockArgs(
        this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
    if (failed(regionResult))
      return regionResult;
  }
  if (!left.empty()) {
    regionResult =
        verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
    if (failed(regionResult))
      return regionResult;
  } else if (getLeftIdentity()) {
    if (leftType != outputType)
      return emitError("left=identity requires first argument to have the same "
                       "type as the output");
  }
  if (!right.empty()) {
    regionResult = verifyNumBlockArgs(this, right, "right",
                                      TypeRange{rightType}, outputType);
    if (failed(regionResult))
      return regionResult;
  } else if (getRightIdentity()) {
    if (rightType != outputType)
      return emitError("right=identity requires second argument to have the "
                       "same type as the output");
  }

  return success();
}

LogicalResult UnaryOp::verify() {
  Type inputType = getX().getType();
  Type outputType = getOutput().getType();
  LogicalResult regionResult = success();

  // Check correct number of block arguments and return type for each
  // non-empty region.
  Region &present = getPresentRegion();
  if (!present.empty()) {
    regionResult = verifyNumBlockArgs(this, present, "present",
                                      TypeRange{inputType}, outputType);
    if (failed(regionResult))
      return regionResult;
  }
  Region &absent = getAbsentRegion();
  if (!absent.empty()) {
    regionResult =
        verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
    if (failed(regionResult))
      return regionResult;
  }

  return success();
}

LogicalResult ReduceOp::verify() {
  Type inputType = getX().getType();
  LogicalResult regionResult = success();

  // Check correct number of block arguments and return type.
  Region &formula = getRegion();
  if (!formula.empty()) {
    regionResult = verifyNumBlockArgs(
        this, formula, "reduce", TypeRange{inputType, inputType}, inputType);
    if (failed(regionResult))
      return regionResult;
  }

  return success();
}

LogicalResult YieldOp::verify() {
  // Check for compatible parent.
  auto *parentOp = (*this)->getParentOp();
  if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
      isa<ReduceOp>(parentOp))
    return success();

  return emitOpError(
      "expected parent op to be sparse_tensor unary, binary, or reduce");
}

//===----------------------------------------------------------------------===//
// TensorDialect Methods.
//===----------------------------------------------------------------------===//

void SparseTensorDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
      >();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"

#include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"