1319072f4SAart Bik //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===//
2319072f4SAart Bik //
3319072f4SAart Bik // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4319072f4SAart Bik // See https://llvm.org/LICENSE.txt for license information.
5319072f4SAart Bik // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6319072f4SAart Bik //
7319072f4SAart Bik //===----------------------------------------------------------------------===//
8319072f4SAart Bik 
9319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
105517208dSAart Bik 
11319072f4SAart Bik #include "mlir/IR/Builders.h"
120a292199SAart Bik #include "mlir/IR/DialectImplementation.h"
1365e7cd13SRiver Riddle #include "mlir/IR/Matchers.h"
14319072f4SAart Bik #include "mlir/IR/OpImplementation.h"
150a292199SAart Bik #include "llvm/ADT/TypeSwitch.h"
16319072f4SAart Bik 
17319072f4SAart Bik using namespace mlir;
18319072f4SAart Bik using namespace mlir::sparse_tensor;
19319072f4SAart Bik 
200a292199SAart Bik //===----------------------------------------------------------------------===//
2196a23911SAart Bik // TensorDialect Attribute Methods.
220a292199SAart Bik //===----------------------------------------------------------------------===//
230a292199SAart Bik 
240a292199SAart Bik #define GET_ATTRDEF_CLASSES
250a292199SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
260a292199SAart Bik 
/// Returns true when `bitWidth` is an allowed overhead storage width for the
/// pointer/index arrays of a sparse tensor encoding. A width of 0 is accepted
/// (isMatchingWidth maps it to the `index` type); otherwise only the standard
/// integer widths 8, 16, 32 and 64 are allowed.
static bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
390a292199SAart Bik 
parse(AsmParser & parser,Type type)40f97e72aaSMehdi Amini Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
410a292199SAart Bik   if (failed(parser.parseLess()))
420a292199SAart Bik     return {};
430a292199SAart Bik   // Parse the data as a dictionary.
440a292199SAart Bik   DictionaryAttr dict;
450a292199SAart Bik   if (failed(parser.parseAttribute(dict)))
460a292199SAart Bik     return {};
470a292199SAart Bik   if (failed(parser.parseGreater()))
480a292199SAart Bik     return {};
490a292199SAart Bik   // Process the data from the parsed dictionary value into struct-like data.
500a292199SAart Bik   SmallVector<SparseTensorEncodingAttr::DimLevelType, 4> dlt;
510a292199SAart Bik   AffineMap map = {};
520a292199SAart Bik   unsigned ptr = 0;
530a292199SAart Bik   unsigned ind = 0;
540a292199SAart Bik   for (const NamedAttribute &attr : dict) {
550c7890c8SRiver Riddle     if (attr.getName() == "dimLevelType") {
560c7890c8SRiver Riddle       auto arrayAttr = attr.getValue().dyn_cast<ArrayAttr>();
570a292199SAart Bik       if (!arrayAttr) {
580a292199SAart Bik         parser.emitError(parser.getNameLoc(),
590a292199SAart Bik                          "expected an array for dimension level types");
600a292199SAart Bik         return {};
610a292199SAart Bik       }
62e5639b3fSMehdi Amini       for (auto i : arrayAttr) {
63e5639b3fSMehdi Amini         auto strAttr = i.dyn_cast<StringAttr>();
640a292199SAart Bik         if (!strAttr) {
650a292199SAart Bik           parser.emitError(parser.getNameLoc(),
660a292199SAart Bik                            "expected a string value in dimension level types");
670a292199SAart Bik           return {};
680a292199SAart Bik         }
690a292199SAart Bik         auto strVal = strAttr.getValue();
700a292199SAart Bik         if (strVal == "dense") {
710a292199SAart Bik           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Dense);
720a292199SAart Bik         } else if (strVal == "compressed") {
730a292199SAart Bik           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Compressed);
740a292199SAart Bik         } else if (strVal == "singleton") {
750a292199SAart Bik           dlt.push_back(SparseTensorEncodingAttr::DimLevelType::Singleton);
760a292199SAart Bik         } else {
770a292199SAart Bik           parser.emitError(parser.getNameLoc(),
780a292199SAart Bik                            "unexpected dimension level type: ")
790a292199SAart Bik               << strVal;
800a292199SAart Bik           return {};
810a292199SAart Bik         }
820a292199SAart Bik       }
830c7890c8SRiver Riddle     } else if (attr.getName() == "dimOrdering") {
840c7890c8SRiver Riddle       auto affineAttr = attr.getValue().dyn_cast<AffineMapAttr>();
850a292199SAart Bik       if (!affineAttr) {
860a292199SAart Bik         parser.emitError(parser.getNameLoc(),
870a292199SAart Bik                          "expected an affine map for dimension ordering");
880a292199SAart Bik         return {};
890a292199SAart Bik       }
900a292199SAart Bik       map = affineAttr.getValue();
910c7890c8SRiver Riddle     } else if (attr.getName() == "pointerBitWidth") {
920c7890c8SRiver Riddle       auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
930a292199SAart Bik       if (!intAttr) {
940a292199SAart Bik         parser.emitError(parser.getNameLoc(),
950a292199SAart Bik                          "expected an integral pointer bitwidth");
960a292199SAart Bik         return {};
970a292199SAart Bik       }
980a292199SAart Bik       ptr = intAttr.getInt();
990c7890c8SRiver Riddle     } else if (attr.getName() == "indexBitWidth") {
1000c7890c8SRiver Riddle       auto intAttr = attr.getValue().dyn_cast<IntegerAttr>();
1010a292199SAart Bik       if (!intAttr) {
1020a292199SAart Bik         parser.emitError(parser.getNameLoc(),
1030a292199SAart Bik                          "expected an integral index bitwidth");
1040a292199SAart Bik         return {};
1050a292199SAart Bik       }
1060a292199SAart Bik       ind = intAttr.getInt();
1070a292199SAart Bik     } else {
1080a292199SAart Bik       parser.emitError(parser.getNameLoc(), "unexpected key: ")
1090c7890c8SRiver Riddle           << attr.getName().strref();
1100a292199SAart Bik       return {};
1110a292199SAart Bik     }
1120a292199SAart Bik   }
1130a292199SAart Bik   // Construct struct-like storage for attribute.
114fb093c83SChris Lattner   return parser.getChecked<SparseTensorEncodingAttr>(parser.getContext(), dlt,
115fb093c83SChris Lattner                                                      map, ptr, ind);
1160a292199SAart Bik }
1170a292199SAart Bik 
print(AsmPrinter & printer) const118f97e72aaSMehdi Amini void SparseTensorEncodingAttr::print(AsmPrinter &printer) const {
1190a292199SAart Bik   // Print the struct-like storage in dictionary fashion.
120f30a8a6fSMehdi Amini   printer << "<{ dimLevelType = [ ";
1210a292199SAart Bik   for (unsigned i = 0, e = getDimLevelType().size(); i < e; i++) {
1220a292199SAart Bik     switch (getDimLevelType()[i]) {
1230a292199SAart Bik     case DimLevelType::Dense:
1240a292199SAart Bik       printer << "\"dense\"";
1250a292199SAart Bik       break;
1260a292199SAart Bik     case DimLevelType::Compressed:
1270a292199SAart Bik       printer << "\"compressed\"";
1280a292199SAart Bik       break;
1290a292199SAart Bik     case DimLevelType::Singleton:
1300a292199SAart Bik       printer << "\"singleton\"";
1310a292199SAart Bik       break;
1320a292199SAart Bik     }
1330a292199SAart Bik     if (i != e - 1)
1340a292199SAart Bik       printer << ", ";
1350a292199SAart Bik   }
1360a292199SAart Bik   printer << " ]";
1370a292199SAart Bik   if (getDimOrdering())
1380a292199SAart Bik     printer << ", dimOrdering = affine_map<" << getDimOrdering() << ">";
1390a292199SAart Bik   printer << ", pointerBitWidth = " << getPointerBitWidth()
1400a292199SAart Bik           << ", indexBitWidth = " << getIndexBitWidth() << " }>";
1410a292199SAart Bik }
1420a292199SAart Bik 
verify(function_ref<InFlightDiagnostic ()> emitError,ArrayRef<DimLevelType> dimLevelType,AffineMap dimOrdering,unsigned pointerBitWidth,unsigned indexBitWidth)1430a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verify(
1440a292199SAart Bik     function_ref<InFlightDiagnostic()> emitError,
1450a292199SAart Bik     ArrayRef<DimLevelType> dimLevelType, AffineMap dimOrdering,
1460a292199SAart Bik     unsigned pointerBitWidth, unsigned indexBitWidth) {
1470a292199SAart Bik   if (!acceptBitWidth(pointerBitWidth))
1480a292199SAart Bik     return emitError() << "unexpected pointer bitwidth: " << pointerBitWidth;
1490a292199SAart Bik   if (!acceptBitWidth(indexBitWidth))
1500a292199SAart Bik     return emitError() << "unexpected index bitwidth: " << indexBitWidth;
1510a292199SAart Bik   if (dimOrdering) {
1520a292199SAart Bik     if (!dimOrdering.isPermutation())
1530a292199SAart Bik       return emitError()
1540a292199SAart Bik              << "expected a permutation affine map for dimension ordering";
1550a292199SAart Bik     if (dimOrdering.getNumResults() != dimLevelType.size())
1560a292199SAart Bik       return emitError() << "unexpected mismatch in ordering and dimension "
1570a292199SAart Bik                             "level types size";
1580a292199SAart Bik   }
1590a292199SAart Bik   return success();
1600a292199SAart Bik }
1610a292199SAart Bik 
verifyEncoding(ArrayRef<int64_t> shape,Type elementType,function_ref<InFlightDiagnostic ()> emitError) const1620a292199SAart Bik LogicalResult SparseTensorEncodingAttr::verifyEncoding(
1630a292199SAart Bik     ArrayRef<int64_t> shape, Type elementType,
1640a292199SAart Bik     function_ref<InFlightDiagnostic()> emitError) const {
1650a292199SAart Bik   // Check structural integrity.
1660a292199SAart Bik   if (failed(verify(emitError, getDimLevelType(), getDimOrdering(),
1670a292199SAart Bik                     getPointerBitWidth(), getIndexBitWidth())))
1680a292199SAart Bik     return failure();
1690a292199SAart Bik   // Check integrity with tensor type specifics. Dimension ordering is optional,
1700a292199SAart Bik   // but we always should have dimension level types for the full rank.
1710a292199SAart Bik   unsigned size = shape.size();
1724aa9b398SAart Bik   if (size == 0)
1734aa9b398SAart Bik     return emitError() << "expected non-scalar sparse tensor";
1740a292199SAart Bik   if (getDimOrdering() && getDimOrdering().getNumResults() != size)
1750a292199SAart Bik     return emitError() << "expected an affine map of size " << size
1760a292199SAart Bik                        << " for dimension ordering";
1770a292199SAart Bik   if (getDimLevelType().size() != size)
1780a292199SAart Bik     return emitError() << "expected an array of size " << size
1790a292199SAart Bik                        << " for dimension level types";
1800a292199SAart Bik   return success();
1810a292199SAart Bik }
1820a292199SAart Bik 
18396a23911SAart Bik SparseTensorEncodingAttr
getSparseTensorEncoding(Type type)18496a23911SAart Bik mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
18596a23911SAart Bik   if (auto ttp = type.dyn_cast<RankedTensorType>())
18696a23911SAart Bik     return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
18796a23911SAart Bik   return nullptr;
18896a23911SAart Bik }
18996a23911SAart Bik 
1900a292199SAart Bik //===----------------------------------------------------------------------===//
19196a23911SAart Bik // TensorDialect Operations.
19296a23911SAart Bik //===----------------------------------------------------------------------===//
19396a23911SAart Bik 
isInBounds(Value dim,Value tensor)19496a23911SAart Bik static LogicalResult isInBounds(Value dim, Value tensor) {
19565e7cd13SRiver Riddle   IntegerAttr constantAttr;
19665e7cd13SRiver Riddle   if (matchPattern(dim, m_Constant(&constantAttr))) {
19765e7cd13SRiver Riddle     unsigned d = constantAttr.getInt();
19896a23911SAart Bik     if (d >= tensor.getType().cast<RankedTensorType>().getRank())
19996a23911SAart Bik       return failure();
20096a23911SAart Bik   }
20196a23911SAart Bik   return success(); // in bounds, or symbolic
20296a23911SAart Bik }
20396a23911SAart Bik 
isMatchingWidth(Value result,unsigned width)20496a23911SAart Bik static LogicalResult isMatchingWidth(Value result, unsigned width) {
20596a23911SAart Bik   Type etp = result.getType().cast<MemRefType>().getElementType();
20696a23911SAart Bik   if ((width == 0 && etp.isIndex()) || (width > 0 && etp.isInteger(width)))
20796a23911SAart Bik     return success();
20896a23911SAart Bik   return failure();
20996a23911SAart Bik }
21096a23911SAart Bik 
verify()211b98dc035SRiver Riddle LogicalResult ConvertOp::verify() {
21204235d07SJacques Pienaar   if (auto tp1 = getSource().getType().dyn_cast<RankedTensorType>()) {
21304235d07SJacques Pienaar     if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) {
2141e6ef0cfSAart Bik       if (tp1.getRank() != tp2.getRank())
215b98dc035SRiver Riddle         return emitError("unexpected conversion mismatch in rank");
216697ea09dSAart Bik       auto shape1 = tp1.getShape();
217697ea09dSAart Bik       auto shape2 = tp2.getShape();
2189d1db3d4SAart Bik       // Accept size matches between the source and the destination type
2199d1db3d4SAart Bik       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
2209d1db3d4SAart Bik       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
221b98dc035SRiver Riddle       for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
2229d1db3d4SAart Bik         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
223b98dc035SRiver Riddle           return emitError("unexpected conversion mismatch in dimension ") << d;
224697ea09dSAart Bik       return success();
225697ea09dSAart Bik     }
226697ea09dSAart Bik   }
227b98dc035SRiver Riddle   return emitError("unexpected type in convert");
228697ea09dSAart Bik }
229697ea09dSAart Bik 
fold(ArrayRef<Attribute> operands)230066d786cSAart Bik OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
23104235d07SJacques Pienaar   if (getType() == getSource().getType())
23204235d07SJacques Pienaar     return getSource();
233066d786cSAart Bik   return {};
234066d786cSAart Bik }
235066d786cSAart Bik 
verify()236b98dc035SRiver Riddle LogicalResult ToPointersOp::verify() {
23704235d07SJacques Pienaar   auto e = getSparseTensorEncoding(getTensor().getType());
23804235d07SJacques Pienaar   if (failed(isInBounds(getDim(), getTensor())))
239b98dc035SRiver Riddle     return emitError("requested pointers dimension out of bounds");
24004235d07SJacques Pienaar   if (failed(isMatchingWidth(getResult(), e.getPointerBitWidth())))
241b98dc035SRiver Riddle     return emitError("unexpected type for pointers");
24296a23911SAart Bik   return success();
24396a23911SAart Bik }
24496a23911SAart Bik 
verify()245b98dc035SRiver Riddle LogicalResult ToIndicesOp::verify() {
24604235d07SJacques Pienaar   auto e = getSparseTensorEncoding(getTensor().getType());
24704235d07SJacques Pienaar   if (failed(isInBounds(getDim(), getTensor())))
248b98dc035SRiver Riddle     return emitError("requested indices dimension out of bounds");
24904235d07SJacques Pienaar   if (failed(isMatchingWidth(getResult(), e.getIndexBitWidth())))
250b98dc035SRiver Riddle     return emitError("unexpected type for indices");
25196a23911SAart Bik   return success();
25296a23911SAart Bik }
25396a23911SAart Bik 
verify()254b98dc035SRiver Riddle LogicalResult ToValuesOp::verify() {
25504235d07SJacques Pienaar   RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
25604235d07SJacques Pienaar   MemRefType mtp = getResult().getType().cast<MemRefType>();
25796a23911SAart Bik   if (ttp.getElementType() != mtp.getElementType())
258b98dc035SRiver Riddle     return emitError("unexpected mismatch in element types");
25996a23911SAart Bik   return success();
26096a23911SAart Bik }
26196a23911SAart Bik 
262f66e5769SAart Bik //===----------------------------------------------------------------------===//
263414ed019SJim Kitchen // TensorDialect Linalg.Generic Operations.
264414ed019SJim Kitchen //===----------------------------------------------------------------------===//
265414ed019SJim Kitchen 
266414ed019SJim Kitchen template <class T>
verifyNumBlockArgs(T * op,Region & region,const char * regionName,TypeRange inputTypes,Type outputType)267414ed019SJim Kitchen static LogicalResult verifyNumBlockArgs(T *op, Region &region,
268414ed019SJim Kitchen                                         const char *regionName,
269414ed019SJim Kitchen                                         TypeRange inputTypes, Type outputType) {
270414ed019SJim Kitchen   unsigned numArgs = region.getNumArguments();
271414ed019SJim Kitchen   unsigned expectedNum = inputTypes.size();
272414ed019SJim Kitchen   if (numArgs != expectedNum)
273414ed019SJim Kitchen     return op->emitError() << regionName << " region must have exactly "
274414ed019SJim Kitchen                            << expectedNum << " arguments";
275414ed019SJim Kitchen 
276414ed019SJim Kitchen   for (unsigned i = 0; i < numArgs; i++) {
277414ed019SJim Kitchen     Type typ = region.getArgument(i).getType();
278414ed019SJim Kitchen     if (typ != inputTypes[i])
279414ed019SJim Kitchen       return op->emitError() << regionName << " region argument " << (i + 1)
280414ed019SJim Kitchen                              << " type mismatch";
281414ed019SJim Kitchen   }
282414ed019SJim Kitchen   Operation *term = region.front().getTerminator();
283414ed019SJim Kitchen   YieldOp yield = dyn_cast<YieldOp>(term);
284414ed019SJim Kitchen   if (!yield)
285414ed019SJim Kitchen     return op->emitError() << regionName
286414ed019SJim Kitchen                            << " region must end with sparse_tensor.yield";
287414ed019SJim Kitchen   if (yield.getOperand().getType() != outputType)
288414ed019SJim Kitchen     return op->emitError() << regionName << " region yield type mismatch";
289414ed019SJim Kitchen 
290414ed019SJim Kitchen   return success();
291414ed019SJim Kitchen }
292414ed019SJim Kitchen 
verify()293414ed019SJim Kitchen LogicalResult BinaryOp::verify() {
294414ed019SJim Kitchen   NamedAttrList attrs = (*this)->getAttrs();
29504235d07SJacques Pienaar   Type leftType = getX().getType();
29604235d07SJacques Pienaar   Type rightType = getY().getType();
29704235d07SJacques Pienaar   Type outputType = getOutput().getType();
29804235d07SJacques Pienaar   Region &overlap = getOverlapRegion();
29904235d07SJacques Pienaar   Region &left = getLeftRegion();
30004235d07SJacques Pienaar   Region &right = getRightRegion();
301414ed019SJim Kitchen 
302414ed019SJim Kitchen   // Check correct number of block arguments and return type for each
303414ed019SJim Kitchen   // non-empty region.
304414ed019SJim Kitchen   LogicalResult regionResult = success();
305414ed019SJim Kitchen   if (!overlap.empty()) {
306414ed019SJim Kitchen     regionResult = verifyNumBlockArgs(
307414ed019SJim Kitchen         this, overlap, "overlap", TypeRange{leftType, rightType}, outputType);
308414ed019SJim Kitchen     if (failed(regionResult))
309414ed019SJim Kitchen       return regionResult;
310414ed019SJim Kitchen   }
311414ed019SJim Kitchen   if (!left.empty()) {
312414ed019SJim Kitchen     regionResult =
313414ed019SJim Kitchen         verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType);
314414ed019SJim Kitchen     if (failed(regionResult))
315414ed019SJim Kitchen       return regionResult;
31604235d07SJacques Pienaar   } else if (getLeftIdentity()) {
317414ed019SJim Kitchen     if (leftType != outputType)
318414ed019SJim Kitchen       return emitError("left=identity requires first argument to have the same "
319414ed019SJim Kitchen                        "type as the output");
320414ed019SJim Kitchen   }
321414ed019SJim Kitchen   if (!right.empty()) {
322414ed019SJim Kitchen     regionResult = verifyNumBlockArgs(this, right, "right",
323414ed019SJim Kitchen                                       TypeRange{rightType}, outputType);
324414ed019SJim Kitchen     if (failed(regionResult))
325414ed019SJim Kitchen       return regionResult;
32604235d07SJacques Pienaar   } else if (getRightIdentity()) {
327414ed019SJim Kitchen     if (rightType != outputType)
328414ed019SJim Kitchen       return emitError("right=identity requires second argument to have the "
329414ed019SJim Kitchen                        "same type as the output");
330414ed019SJim Kitchen   }
331414ed019SJim Kitchen 
332414ed019SJim Kitchen   return success();
333414ed019SJim Kitchen }
334414ed019SJim Kitchen 
verify()335414ed019SJim Kitchen LogicalResult UnaryOp::verify() {
33604235d07SJacques Pienaar   Type inputType = getX().getType();
33704235d07SJacques Pienaar   Type outputType = getOutput().getType();
338414ed019SJim Kitchen   LogicalResult regionResult = success();
339414ed019SJim Kitchen 
340414ed019SJim Kitchen   // Check correct number of block arguments and return type for each
341414ed019SJim Kitchen   // non-empty region.
34204235d07SJacques Pienaar   Region &present = getPresentRegion();
343414ed019SJim Kitchen   if (!present.empty()) {
344414ed019SJim Kitchen     regionResult = verifyNumBlockArgs(this, present, "present",
345414ed019SJim Kitchen                                       TypeRange{inputType}, outputType);
346414ed019SJim Kitchen     if (failed(regionResult))
347414ed019SJim Kitchen       return regionResult;
348414ed019SJim Kitchen   }
34904235d07SJacques Pienaar   Region &absent = getAbsentRegion();
350414ed019SJim Kitchen   if (!absent.empty()) {
351414ed019SJim Kitchen     regionResult =
352414ed019SJim Kitchen         verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType);
353414ed019SJim Kitchen     if (failed(regionResult))
354414ed019SJim Kitchen       return regionResult;
355414ed019SJim Kitchen   }
356414ed019SJim Kitchen 
357414ed019SJim Kitchen   return success();
358414ed019SJim Kitchen }
359414ed019SJim Kitchen 
verify()3602b8a4d9cSJim Kitchen LogicalResult ReduceOp::verify() {
361*a1ec0d8bSJacques Pienaar   Type inputType = getX().getType();
3622b8a4d9cSJim Kitchen   LogicalResult regionResult = success();
3632b8a4d9cSJim Kitchen 
3642b8a4d9cSJim Kitchen   // Check correct number of block arguments and return type.
365*a1ec0d8bSJacques Pienaar   Region &formula = getRegion();
3662b8a4d9cSJim Kitchen   if (!formula.empty()) {
3672b8a4d9cSJim Kitchen     regionResult = verifyNumBlockArgs(
3682b8a4d9cSJim Kitchen         this, formula, "reduce", TypeRange{inputType, inputType}, inputType);
3692b8a4d9cSJim Kitchen     if (failed(regionResult))
3702b8a4d9cSJim Kitchen       return regionResult;
3712b8a4d9cSJim Kitchen   }
3722b8a4d9cSJim Kitchen 
3732b8a4d9cSJim Kitchen   return success();
3742b8a4d9cSJim Kitchen }
3752b8a4d9cSJim Kitchen 
verify()376414ed019SJim Kitchen LogicalResult YieldOp::verify() {
377414ed019SJim Kitchen   // Check for compatible parent.
378414ed019SJim Kitchen   auto *parentOp = (*this)->getParentOp();
3792b8a4d9cSJim Kitchen   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
3802b8a4d9cSJim Kitchen       isa<ReduceOp>(parentOp))
381414ed019SJim Kitchen     return success();
382414ed019SJim Kitchen 
3832b8a4d9cSJim Kitchen   return emitOpError(
3842b8a4d9cSJim Kitchen       "expected parent op to be sparse_tensor unary, binary, or reduce");
385414ed019SJim Kitchen }
386414ed019SJim Kitchen 
387414ed019SJim Kitchen //===----------------------------------------------------------------------===//
38896a23911SAart Bik // TensorDialect Methods.
3890a292199SAart Bik //===----------------------------------------------------------------------===//
3900a292199SAart Bik 
initialize()391319072f4SAart Bik void SparseTensorDialect::initialize() {
3920a292199SAart Bik   addAttributes<
3930a292199SAart Bik #define GET_ATTRDEF_LIST
3940a292199SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
3950a292199SAart Bik       >();
396319072f4SAart Bik   addOperations<
397319072f4SAart Bik #define GET_OP_LIST
398319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
399319072f4SAart Bik       >();
400319072f4SAart Bik }
401319072f4SAart Bik 
402319072f4SAart Bik #define GET_OP_CLASSES
403319072f4SAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc"
4045517208dSAart Bik 
4055517208dSAart Bik #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc"
406