1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Tensor/IR/Tensor.h"
10 #include "mlir/IR/DialectImplementation.h"
11 #include "mlir/Transforms/InliningUtils.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 using namespace mlir;
15 using namespace mlir::tensor;
16 
17 //===----------------------------------------------------------------------===//
18 // TableGen'd Attributes Methods
19 //===----------------------------------------------------------------------===//
20 
21 #define GET_ATTRDEF_CLASSES
22 #include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
23 
24 // Dictionary keys.
25 static constexpr StringRef getSparseDimLevelTypeAttrName() {
26   return "sparseDimLevelType";
27 }
28 static constexpr StringRef getSparseDimOrderingAttrName() {
29   return "sparseDimOrdering";
30 }
31 static constexpr StringRef getSparsePointerBitWidthAttrName() {
32   return "sparsePointerBitWidth";
33 }
34 static constexpr StringRef getSparseIndexBitWidthAttrName() {
35   return "sparseIndexBitWidth";
36 }
37 
38 // Dictionary values.
39 static constexpr StringRef getDenseDimLevelTypeVal() { return "dense"; }
40 static constexpr StringRef getCompressedDimLevelTypeVal() {
41   return "compressed";
42 }
43 static constexpr StringRef getSingletonDimLevelTypeVal() { return "singleton"; }
44 
45 Attribute SparseTensorEncodingAttr::parse(MLIRContext *context,
46                                           DialectAsmParser &parser, Type type) {
47   if (failed(parser.parseLess()))
48     return {};
49   DictionaryAttr dict;
50   if (failed(parser.parseAttribute(dict)))
51     return {};
52   if (failed(parser.parseGreater()))
53     return {};
54   return SparseTensorEncodingAttr::get(context, dict);
55 }
56 
57 void SparseTensorEncodingAttr::print(DialectAsmPrinter &printer) const {
58   printer << "sparse<" << getDict() << ">";
59 }
60 
61 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
62     llvm::ArrayRef<int64_t> shape, Type elementType,
63     llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
64   unsigned size = shape.size();
65   for (const NamedAttribute &attr : getDict()) {
66     if (attr.first == getSparseDimLevelTypeAttrName()) {
67       // Dimension level type verification.
68       auto arrayAttr = attr.second.dyn_cast<ArrayAttr>();
69       if (!arrayAttr || size != static_cast<int64_t>(arrayAttr.size()))
70         return emitError() << "expected an array of size " << size
71                            << " for dimension level types";
72       for (unsigned i = 0; i < size; i++) {
73         auto strAttr = arrayAttr[i].dyn_cast<StringAttr>();
74         if (!strAttr)
75           return emitError()
76                  << "expected string value in dimension level types";
77         auto strVal = strAttr.getValue();
78         if (strVal != getDenseDimLevelTypeVal() &&
79             strVal != getCompressedDimLevelTypeVal() &&
80             strVal != getSingletonDimLevelTypeVal())
81           return emitError() << "unexpected dimension level type: " << strAttr;
82       }
83     } else if (attr.first == getSparseDimOrderingAttrName()) {
84       // Dimension order verification.
85       auto affineAttr = attr.second.dyn_cast<AffineMapAttr>();
86       if (!affineAttr)
87         return emitError() << "expected an affine map for dimension ordering";
88       AffineMap map = affineAttr.getValue();
89       if (size != map.getNumResults() || !map.isPermutation())
90         return emitError() << "expected a permutation affine map of size "
91                            << size << " for dimension ordering";
92     } else if (attr.first == getSparsePointerBitWidthAttrName() ||
93                attr.first == getSparseIndexBitWidthAttrName()) {
94       // Pointer or index bitwidth verification.
95       auto intAttr = attr.second.dyn_cast<IntegerAttr>();
96       if (!intAttr)
97         return emitError() << "expected an integral bitwidth";
98       switch (intAttr.getInt()) {
99       case 0:
100       case 8:
101       case 16:
102       case 32:
103       case 64:
104         continue;
105       default:
106         return emitError() << "unexpected bitwidth: " << intAttr.getInt();
107       }
108     } else {
109       return emitError() << "unexpected key: " << attr.first.str();
110     }
111   }
112   return success();
113 }
114 
115 SparseTensorEncodingAttr::DimLevelType
116 SparseTensorEncodingAttr::getDimLevelType(unsigned dim) const {
117   if (auto value = getDict().get(getSparseDimLevelTypeAttrName())) {
118     auto strVal =
119         value.dyn_cast<ArrayAttr>()[dim].cast<StringAttr>().getValue();
120     if (strVal == getCompressedDimLevelTypeVal())
121       return DimLevelType::Compressed;
122     if (strVal == getSingletonDimLevelTypeVal())
123       return DimLevelType::Singleton;
124   }
125   return DimLevelType::Dense;
126 }
127 
128 AffineMap SparseTensorEncodingAttr::getDimOrdering() const {
129   if (auto value = getDict().get(getSparseDimOrderingAttrName()))
130     return value.cast<AffineMapAttr>().getValue();
131   return {};
132 }
133 
134 unsigned SparseTensorEncodingAttr::getPointerBitWidth() const {
135   if (auto value = getDict().get(getSparsePointerBitWidthAttrName()))
136     return value.cast<IntegerAttr>().getInt();
137   return 0;
138 }
139 
140 unsigned SparseTensorEncodingAttr::getIndexBitWidth() const {
141   if (auto value = getDict().get(getSparseIndexBitWidthAttrName()))
142     return value.cast<IntegerAttr>().getInt();
143   return 0;
144 }
145 
146 //===----------------------------------------------------------------------===//
147 // TensorDialect Dialect Interfaces
148 //===----------------------------------------------------------------------===//
149 
150 namespace {
151 struct TensorInlinerInterface : public DialectInlinerInterface {
152   using DialectInlinerInterface::DialectInlinerInterface;
153   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
154                        BlockAndValueMapping &valueMapping) const final {
155     return true;
156   }
157   bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
158                        BlockAndValueMapping &) const final {
159     return true;
160   }
161 };
162 } // end anonymous namespace
163 
164 //===----------------------------------------------------------------------===//
165 // TensorDialect Methods
166 //===----------------------------------------------------------------------===//
167 
/// Registers the tensor dialect's attributes, operations and interfaces.
void TensorDialect::initialize() {
  // Attribute classes are taken from the TableGen'd attribute list.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
      >();
  // Operation classes are taken from the TableGen'd op list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
      >();
  // Inlining support, defined in the anonymous namespace above.
  addInterfaces<TensorInlinerInterface>();
}
179 
180 Attribute TensorDialect::parseAttribute(DialectAsmParser &parser,
181                                         Type type) const {
182   StringRef attrTag;
183   if (failed(parser.parseKeyword(&attrTag)))
184     return Attribute();
185   Attribute attr;
186   auto parseResult =
187       generatedAttributeParser(getContext(), parser, attrTag, type, attr);
188   if (parseResult.hasValue())
189     return attr;
190   parser.emitError(parser.getNameLoc(), "unknown tensor attribute");
191   return Attribute();
192 }
193 
194 void TensorDialect::printAttribute(::mlir::Attribute attr,
195                                    ::mlir::DialectAsmPrinter &printer) const {
196   if (succeeded(generatedAttributePrinter(attr, printer)))
197     return;
198 }
199