1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "CodegenUtils.h"
10 
11 #include "mlir/IR/Types.h"
12 #include "mlir/IR/Value.h"
13 
14 using namespace mlir;
15 using namespace mlir::sparse_tensor;
16 
17 //===----------------------------------------------------------------------===//
18 // ExecutionEngine/SparseTensorUtils helper functions.
19 //===----------------------------------------------------------------------===//
20 
21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
22   switch (width) {
23   case 64:
24     return OverheadType::kU64;
25   case 32:
26     return OverheadType::kU32;
27   case 16:
28     return OverheadType::kU16;
29   case 8:
30     return OverheadType::kU8;
31   case 0:
32     return OverheadType::kIndex;
33   }
34   llvm_unreachable("Unsupported overhead bitwidth");
35 }
36 
37 OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
38   if (tp.isIndex())
39     return OverheadType::kIndex;
40   if (auto intTp = tp.dyn_cast<IntegerType>())
41     return overheadTypeEncoding(intTp.getWidth());
42   llvm_unreachable("Unknown overhead type");
43 }
44 
45 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
46   switch (ot) {
47   case OverheadType::kIndex:
48     return builder.getIndexType();
49   case OverheadType::kU64:
50     return builder.getIntegerType(64);
51   case OverheadType::kU32:
52     return builder.getIntegerType(32);
53   case OverheadType::kU16:
54     return builder.getIntegerType(16);
55   case OverheadType::kU8:
56     return builder.getIntegerType(8);
57   }
58   llvm_unreachable("Unknown OverheadType");
59 }
60 
61 OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding(
62     const SparseTensorEncodingAttr &enc) {
63   return overheadTypeEncoding(enc.getPointerBitWidth());
64 }
65 
66 OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding(
67     const SparseTensorEncodingAttr &enc) {
68   return overheadTypeEncoding(enc.getIndexBitWidth());
69 }
70 
71 Type mlir::sparse_tensor::getPointerOverheadType(
72     Builder &builder, const SparseTensorEncodingAttr &enc) {
73   return getOverheadType(builder, pointerOverheadTypeEncoding(enc));
74 }
75 
76 Type mlir::sparse_tensor::getIndexOverheadType(
77     Builder &builder, const SparseTensorEncodingAttr &enc) {
78   return getOverheadType(builder, indexOverheadTypeEncoding(enc));
79 }
80 
81 // TODO: Adjust the naming convention for the constructors of `OverheadType`
82 // and the function-suffix for `kIndex` so we can use the `FOREVERY_O`
83 // x-macro here instead of `FOREVERY_FIXED_O`; to further reduce the
84 // possibility of typo bugs or things getting out of sync.
85 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
86   switch (ot) {
87   case OverheadType::kIndex:
88     return "";
89 #define CASE(ONAME, O)                                                         \
90   case OverheadType::kU##ONAME:                                                \
91     return #ONAME;
92     FOREVERY_FIXED_O(CASE)
93 #undef CASE
94   }
95   llvm_unreachable("Unknown OverheadType");
96 }
97 
98 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
99   return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
100 }
101 
102 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
103   if (elemTp.isF64())
104     return PrimaryType::kF64;
105   if (elemTp.isF32())
106     return PrimaryType::kF32;
107   if (elemTp.isInteger(64))
108     return PrimaryType::kI64;
109   if (elemTp.isInteger(32))
110     return PrimaryType::kI32;
111   if (elemTp.isInteger(16))
112     return PrimaryType::kI16;
113   if (elemTp.isInteger(8))
114     return PrimaryType::kI8;
115   if (auto complexTp = elemTp.dyn_cast<ComplexType>()) {
116     auto complexEltTp = complexTp.getElementType();
117     if (complexEltTp.isF64())
118       return PrimaryType::kC64;
119     if (complexEltTp.isF32())
120       return PrimaryType::kC32;
121   }
122   llvm_unreachable("Unknown primary type");
123 }
124 
125 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
126   switch (pt) {
127 #define CASE(VNAME, V)                                                         \
128   case PrimaryType::k##VNAME:                                                  \
129     return #VNAME;
130     FOREVERY_V(CASE)
131 #undef CASE
132   }
133   llvm_unreachable("Unknown PrimaryType");
134 }
135 
136 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
137   return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
138 }
139 
140 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
141     SparseTensorEncodingAttr::DimLevelType dlt) {
142   switch (dlt) {
143   case SparseTensorEncodingAttr::DimLevelType::Dense:
144     return DimLevelType::kDense;
145   case SparseTensorEncodingAttr::DimLevelType::Compressed:
146     return DimLevelType::kCompressed;
147   case SparseTensorEncodingAttr::DimLevelType::Singleton:
148     return DimLevelType::kSingleton;
149   }
150   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // Misc code generators.
155 //===----------------------------------------------------------------------===//
156 
157 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
158   if (tp.isa<FloatType>())
159     return builder.getFloatAttr(tp, 1.0);
160   if (tp.isa<IndexType>())
161     return builder.getIndexAttr(1);
162   if (auto intTp = tp.dyn_cast<IntegerType>())
163     return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
164   if (tp.isa<RankedTensorType, VectorType>()) {
165     auto shapedTp = tp.cast<ShapedType>();
166     if (auto one = getOneAttr(builder, shapedTp.getElementType()))
167       return DenseElementsAttr::get(shapedTp, one);
168   }
169   llvm_unreachable("Unsupported attribute type");
170 }
171 
172 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
173                                         Value v) {
174   Type tp = v.getType();
175   Value zero = constantZero(builder, loc, tp);
176   if (tp.isa<FloatType>())
177     return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
178                                          zero);
179   if (tp.isIntOrIndex())
180     return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
181                                          zero);
182   if (tp.dyn_cast<ComplexType>())
183     return builder.create<complex::NotEqualOp>(loc, v, zero);
184   llvm_unreachable("Non-numeric type");
185 }
186