1 //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
10
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/DialectImplementation.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "llvm/ADT/TypeSwitch.h"
16
17 using namespace mlir;
18 using namespace mlir::sparse_tensor;
19
20 //===----------------------------------------------------------------------===//
// SparseTensorDialect Attribute Methods.
22 //===----------------------------------------------------------------------===//
23
24 #define GET_ATTRDEF_CLASSES
25 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
26
/// Returns true when `bitWidth` is a storage width the encoding accepts for
/// pointers or indices: 0 (which isMatchingWidth maps to the `index` type) or
/// one of the fixed integer widths 8, 16, 32, 64.
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
39
parse(AsmParser & parser,Type type)40 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
41 if (failed(parser.parseLess()))
42 return {};
43 // Parse the data as a dictionary.
44 DictionaryAttr dict;
45 if (failed(parser.parseAttribute(dict)))
46 return {};
47 if (failed(parser.parseGreater()))
48 return {};
49 // Process the data from the parsed dictionary value into struct-like data.
50 SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
51 AffineMap map = {};
52 unsigned ptr = 0;
53 unsigned ind = 0;
54 for (const NamedAttribute &attr : dict) {
55 if (attr.getName() == "dimLevelType") {
56 auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
57 if (!arrayAttr) {
58 parser.emitError(parser.getNameLoc(),
59 "expected an array for dimension level types");
60 return {};
61 }
62 for (auto i : arrayAttr) {
63 auto strAttr = i.dyn_cast<StringAttr>();
64 if (!strAttr) {
65 parser.emitError(parser.getNameLoc(),
66 "expected a string value in dimension level types");
67 return {};
68 }
69 auto strVal = strAttr.getValue();
70 if (strVal == "dense") {
71 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
72 } else if (strVal == "compressed") {
73 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
74 } else if (strVal == "singleton") {
75 dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
76 } else {
77 parser.emitError(parser.getNameLoc(),
78 "unexpected dimension level type: ")
79 << strVal;
80 return {};
81 }
82 }
83 } else if (attr.getName() == "dimOrdering") {
84 auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
85 if (!affineAttr) {
86 parser.emitError(parser.getNameLoc(),
87 "expected an affine map for dimension ordering");
88 return {};
89 }
90 map = affineAttr.getValue();
91 } else if (attr.getName() == "pointerBitWidth") {
92 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
93 if (!intAttr) {
94 parser.emitError(parser.getNameLoc(),
95 "expected an integral pointer bitwidth");
96 return {};
97 }
98 ptr = intAttr.getInt();
99 } else if (attr.getName() == "indexBitWidth") {
100 auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
101 if (!intAttr) {
102 parser.emitError(parser.getNameLoc(),
103 "expected an integral index bitwidth");
104 return {};
105 }
106 ind = intAttr.getInt();
107 } else {
108 parser.emitError(parser.getNameLoc(), "unexpected key: ")
109 << attr.getName().strref();
110 return {};
111 }
112 }
113 // Construct struct-like storage for attribute.
114 return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
115 map, ptr, ind);
116 }
117
print(AsmPrinter & printer) const118 void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
119 // Print the struct-like storage in dictionary fashion.
120 printer << "<{ dimLevelType = [ ";
121 for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
122 switch (getDimLevelType()[i]) {
123 case DimLevelType::Dense:
124 printer << "\"dense\"";
125 break;
126 case DimLevelType::Compressed:
127 printer << "\"compressed\"";
128 break;
129 case DimLevelType::Singleton:
130 printer << "\"singleton\"";
131 break;
132 }
133 if (i != e - 1)
134 printer << ", ";
135 }
136 printer << " ]";
137 if (getDimOrdering())
138 printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
139 printer << ", pointerBitWidth = " << getPointerBitWidth()
140 << ", indexBitWidth = " << getIndexBitWidth() << " }>";
141 }
142
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<DimLevelType> dimLevelType,AffineMap dimOrdering,unsigned pointerBitWidth,unsigned indexBitWidth)143 LogicalResult SparseTensorEncodingAttr::verify(
144 function_ref<InFlightDiagnostic()> emitError,
145 ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
146 unsigned pointerBitWidth, unsigned indexBitWidth) {
147 if (!acceptBitWidth(pointerBitWidth))
148 return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
149 if (!acceptBitWidth(indexBitWidth))
150 return emitError() << "unexpected index bitwidth: " << indexBitWidth;
151 if (dimOrdering) {
152 if (!dimOrdering.isPermutation())
153 return emitError()
154 << "expected a permutation affine map for dimension ordering";
155 if (dimOrdering.getNumResults() != dimLevelType.size())
156 return emitError() << "unexpected mismatch in ordering and dimension "
157 "level types size";
158 }
159 return success();
160 }
161
verifyEncoding(ArrayRef<int64_t> shape,Type elementType,function_ref<InFlightDiagnostic ()> emitError) const162 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
163 ArrayRef<int64_t> shape, Type elementType,
164 function_ref<InFlightDiagnostic()> emitError) const {
165 // Check structural integrity.
166 if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
167 getPointerBitWidth(), getIndexBitWidth())))
168 return failure();
169 // Check integrity with tensor type specifics. Dimension ordering is optional,
170 // but we always should have dimension level types for the full rank.
171 unsigned size = shape.size();
172 if (size == 0)
173 return emitError() << "expected non-scalar sparse tensor";
174 if (getDimOrdering() && getDimOrdering().getNumResults() != size)
175 return emitError() << "expected an affine map of size " << size
176 << " for dimension ordering";
177 if (getDimLevelType().size() != size)
178 return emitError() << "expected an array of size " << size
179 << " for dimension level types";
180 return success();
181 }
182
183 SparseTensorEncodingAttr
getSparseTensorEncoding(Type type)184 mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
185 if (auto ttp = type.dyn_cast<RankedTensorType>())
186 return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
187 return nullptr;
188 }
189
190 //===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
192 //===----------------------------------------------------------------------===//
193
isInBounds(Value dim,Value tensor)194 static LogicalResult isInBounds(Value dim, Value tensor) {
195 IntegerAttr constantAttr;
196 if (matchPattern(dim, m_Constant(&constantAttr))) {
197 unsigned d = constantAttr.getInt();
198 if (d >= tensor.getType().cast<RankedTensorType>().getRank())
199 return failure();
200 }
201 return success(); // in bounds, or symbolic
202 }
203
isMatchingWidth(Value result,unsigned width)204 static LogicalResult isMatchingWidth(Value result, unsigned width) {
205 Type etp = result.getType().cast<MemRefType>().getElementType();
206 if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
207 return success();
208 return failure();
209 }
210
verify()211 LogicalResult ConvertOp::verify() {
212 if (auto tp1 = getSource().getType().dyn_cast<RankedTensorType>()) {
213 if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) {
214 if (tp1.getRank() != tp2.getRank())
215 return emitError("unexpected conversion mismatch in rank");
216 auto shape1 = tp1.getShape();
217 auto shape2 = tp2.getShape();
218 // Accept size matches between the source and the destination type
219 // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
220 // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
221 for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
222 if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
223 return emitError("unexpected conversion mismatch in dimension ") << d;
224 return success();
225 }
226 }
227 return emitError("unexpected type in convert");
228 }
229
fold(ArrayRef<Attribute> operands)230 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
231 if (getType() == getSource().getType())
232 return getSource();
233 return {};
234 }
235
verify()236 LogicalResult ToPointersOp::verify() {
237 auto e = getSparseTensorEncoding(getTensor().getType());
238 if (failed(isInBounds(getDim(), getTensor())))
239 return emitError("requested pointers dimension out of bounds");
240 if (failed(isMatchingWidth(getResult(), e.getPointerBitWidth())))
241 return emitError("unexpected type for pointers");
242 return success();
243 }
244
verify()245 LogicalResult ToIndicesOp::verify() {
246 auto e = getSparseTensorEncoding(getTensor().getType());
247 if (failed(isInBounds(getDim(), getTensor())))
248 return emitError("requested indices dimension out of bounds");
249 if (failed(isMatchingWidth(getResult(), e.getIndexBitWidth())))
250 return emitError("unexpected type for indices");
251 return success();
252 }
253
verify()254 LogicalResult ToValuesOp::verify() {
255 RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
256 MemRefType mtp = getResult().getType().cast<MemRefType>();
257 if (ttp.getElementType() != mtp.getElementType())
258 return emitError("unexpected mismatch in element types");
259 return success();
260 }
261
262 //===----------------------------------------------------------------------===//
// SparseTensorDialect Linalg.Generic Operations.
264 //===----------------------------------------------------------------------===//
265
266 template <class T>
verifyNumBlockArgs(T * op,Region & region,const char * regionName,TypeRange inputTypes,Type outputType)267 static LogicalResult verifyNumBlockArgs(T *op, Region ®ion,
268 const char *regionName,
269 TypeRange inputTypes, Type outputType) {
270 unsigned numArgs = region.getNumArguments();
271 unsigned expectedNum = inputTypes.size();
272 if (numArgs != expectedNum)
273 return op->emitError() << regionName << " region must have exactly "
274 << expectedNum << " arguments";
275
276 for (unsigned i = 0; i < numArgs; i++) {
277 Type typ = region.getArgument(i).getType();
278 if (typ != inputTypes[i])
279 return op->emitError() << regionName << " region argument " << (i + 1)
280 << " type mismatch";
281 }
282 Operation *term = region.front().getTerminator();
283 YieldOp yield = dyn_cast<YieldOp>(term);
284 if (!yield)
285 return op->emitError() << regionName
286 << " region must end with sparse_tensor.yield";
287 if (yield.getOperand().getType() != outputType)
288 return op->emitError() << regionName << " region yield type mismatch";
289
290 return success();
291 }
292
verify()293 LogicalResult BinaryOp::verify() {
294 NamedAttrList attrs = (*this)->getAttrs();
295 Type leftType = getX().getType();
296 Type rightType = getY().getType();
297 Type outputType = getOutput().getType();
298 Region &overlap = getOverlapRegion();
299 Region &left = getLeftRegion();
300 Region &right = getRightRegion();
301
302 // Check correct number of block arguments and return type for each
303 // non-empty region.
304 LogicalResult regionResult = success();
305 if (!overlap.empty()) {
306 regionResult = verifyNumBlockArgs(
307 this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
308 if (failed(regionResult))
309 return regionResult;
310 }
311 if (!left.empty()) {
312 regionResult =
313 verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
314 if (failed(regionResult))
315 return regionResult;
316 } else if (getLeftIdentity()) {
317 if (leftType != outputType)
318 return emitError("left=identity requires first argument to have the same "
319 "type as the output");
320 }
321 if (!right.empty()) {
322 regionResult = verifyNumBlockArgs(this, right, "right",
323 TypeRange{rightType}, outputType);
324 if (failed(regionResult))
325 return regionResult;
326 } else if (getRightIdentity()) {
327 if (rightType != outputType)
328 return emitError("right=identity requires second argument to have the "
329 "same type as the output");
330 }
331
332 return success();
333 }
334
verify()335 LogicalResult UnaryOp::verify() {
336 Type inputType = getX().getType();
337 Type outputType = getOutput().getType();
338 LogicalResult regionResult = success();
339
340 // Check correct number of block arguments and return type for each
341 // non-empty region.
342 Region &present = getPresentRegion();
343 if (!present.empty()) {
344 regionResult = verifyNumBlockArgs(this, present, "present",
345 TypeRange{inputType}, outputType);
346 if (failed(regionResult))
347 return regionResult;
348 }
349 Region &absent = getAbsentRegion();
350 if (!absent.empty()) {
351 regionResult =
352 verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
353 if (failed(regionResult))
354 return regionResult;
355 }
356
357 return success();
358 }
359
verify()360 LogicalResult ReduceOp::verify() {
361 Type inputType = getX().getType();
362 LogicalResult regionResult = success();
363
364 // Check correct number of block arguments and return type.
365 Region &formula = getRegion();
366 if (!formula.empty()) {
367 regionResult = verifyNumBlockArgs(
368 this, formula, "reduce", TypeRange{inputType, inputType}, inputType);
369 if (failed(regionResult))
370 return regionResult;
371 }
372
373 return success();
374 }
375
verify()376 LogicalResult YieldOp::verify() {
377 // Check for compatible parent.
378 auto *parentOp = (*this)->getParentOp();
379 if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
380 isa<ReduceOp>(parentOp))
381 return success();
382
383 return emitOpError(
384 "expected parent op to be sparse_tensor unary, binary, or reduce");
385 }
386
387 //===----------------------------------------------------------------------===//
// SparseTensorDialect Methods.
389 //===----------------------------------------------------------------------===//
390
initialize()391 void SparseTensorDialect::initialize() {
392 addAttributes<
393 #define GET_ATTRDEF_LIST
394 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
395 >();
396 addOperations<
397 #define GET_OP_LIST
398 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
399 >();
400 }
401
402 #define GET_OP_CLASSES
403 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
404
405 #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
406