1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "CodegenUtils.h"
10
11 #include "mlir/IR/Types.h"
12 #include "mlir/IR/Value.h"
13
14 using namespace mlir;
15 using namespace mlir::sparse_tensor;
16
17 //===----------------------------------------------------------------------===//
18 // ExecutionEngine/SparseTensorUtils helper functions.
19 //===----------------------------------------------------------------------===//
20
overheadTypeEncoding(unsigned width)21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
22 switch (width) {
23 case 64:
24 return OverheadType::kU64;
25 case 32:
26 return OverheadType::kU32;
27 case 16:
28 return OverheadType::kU16;
29 case 8:
30 return OverheadType::kU8;
31 case 0:
32 return OverheadType::kIndex;
33 }
34 llvm_unreachable("Unsupported overhead bitwidth");
35 }
36
overheadTypeEncoding(Type tp)37 OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
38 if (tp.isIndex())
39 return OverheadType::kIndex;
40 if (auto intTp = tp.dyn_cast<IntegerType>())
41 return overheadTypeEncoding(intTp.getWidth());
42 llvm_unreachable("Unknown overhead type");
43 }
44
getOverheadType(Builder & builder,OverheadType ot)45 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
46 switch (ot) {
47 case OverheadType::kIndex:
48 return builder.getIndexType();
49 case OverheadType::kU64:
50 return builder.getIntegerType(64);
51 case OverheadType::kU32:
52 return builder.getIntegerType(32);
53 case OverheadType::kU16:
54 return builder.getIntegerType(16);
55 case OverheadType::kU8:
56 return builder.getIntegerType(8);
57 }
58 llvm_unreachable("Unknown OverheadType");
59 }
60
pointerOverheadTypeEncoding(const SparseTensorEncodingAttr & enc)61 OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding(
62 const SparseTensorEncodingAttr &enc) {
63 return overheadTypeEncoding(enc.getPointerBitWidth());
64 }
65
indexOverheadTypeEncoding(const SparseTensorEncodingAttr & enc)66 OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding(
67 const SparseTensorEncodingAttr &enc) {
68 return overheadTypeEncoding(enc.getIndexBitWidth());
69 }
70
getPointerOverheadType(Builder & builder,const SparseTensorEncodingAttr & enc)71 Type mlir::sparse_tensor::getPointerOverheadType(
72 Builder &builder, const SparseTensorEncodingAttr &enc) {
73 return getOverheadType(builder, pointerOverheadTypeEncoding(enc));
74 }
75
getIndexOverheadType(Builder & builder,const SparseTensorEncodingAttr & enc)76 Type mlir::sparse_tensor::getIndexOverheadType(
77 Builder &builder, const SparseTensorEncodingAttr &enc) {
78 return getOverheadType(builder, indexOverheadTypeEncoding(enc));
79 }
80
81 // TODO: Adjust the naming convention for the constructors of
82 // `OverheadType` so we can use the `FOREVERY_O` x-macro here instead
83 // of `FOREVERY_FIXED_O`; to further reduce the possibility of typo bugs
84 // or things getting out of sync.
overheadTypeFunctionSuffix(OverheadType ot)85 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
86 switch (ot) {
87 case OverheadType::kIndex:
88 return "0";
89 #define CASE(ONAME, O) \
90 case OverheadType::kU##ONAME: \
91 return #ONAME;
92 FOREVERY_FIXED_O(CASE)
93 #undef CASE
94 }
95 llvm_unreachable("Unknown OverheadType");
96 }
97
overheadTypeFunctionSuffix(Type tp)98 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
99 return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
100 }
101
primaryTypeEncoding(Type elemTp)102 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
103 if (elemTp.isF64())
104 return PrimaryType::kF64;
105 if (elemTp.isF32())
106 return PrimaryType::kF32;
107 if (elemTp.isF16())
108 return PrimaryType::kF16;
109 if (elemTp.isBF16())
110 return PrimaryType::kBF16;
111 if (elemTp.isInteger(64))
112 return PrimaryType::kI64;
113 if (elemTp.isInteger(32))
114 return PrimaryType::kI32;
115 if (elemTp.isInteger(16))
116 return PrimaryType::kI16;
117 if (elemTp.isInteger(8))
118 return PrimaryType::kI8;
119 if (auto complexTp = elemTp.dyn_cast<ComplexType>()) {
120 auto complexEltTp = complexTp.getElementType();
121 if (complexEltTp.isF64())
122 return PrimaryType::kC64;
123 if (complexEltTp.isF32())
124 return PrimaryType::kC32;
125 }
126 llvm_unreachable("Unknown primary type");
127 }
128
primaryTypeFunctionSuffix(PrimaryType pt)129 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
130 switch (pt) {
131 #define CASE(VNAME, V) \
132 case PrimaryType::k##VNAME: \
133 return #VNAME;
134 FOREVERY_V(CASE)
135 #undef CASE
136 }
137 llvm_unreachable("Unknown PrimaryType");
138 }
139
primaryTypeFunctionSuffix(Type elemTp)140 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
141 return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
142 }
143
dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt)144 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
145 SparseTensorEncodingAttr::DimLevelType dlt) {
146 switch (dlt) {
147 case SparseTensorEncodingAttr::DimLevelType::Dense:
148 return DimLevelType::kDense;
149 case SparseTensorEncodingAttr::DimLevelType::Compressed:
150 return DimLevelType::kCompressed;
151 case SparseTensorEncodingAttr::DimLevelType::Singleton:
152 return DimLevelType::kSingleton;
153 }
154 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
155 }
156
157 //===----------------------------------------------------------------------===//
158 // Misc code generators.
159 //===----------------------------------------------------------------------===//
160
getOneAttr(Builder & builder,Type tp)161 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
162 if (tp.isa<FloatType>())
163 return builder.getFloatAttr(tp, 1.0);
164 if (tp.isa<IndexType>())
165 return builder.getIndexAttr(1);
166 if (auto intTp = tp.dyn_cast<IntegerType>())
167 return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
168 if (tp.isa<RankedTensorType, VectorType>()) {
169 auto shapedTp = tp.cast<ShapedType>();
170 if (auto one = getOneAttr(builder, shapedTp.getElementType()))
171 return DenseElementsAttr::get(shapedTp, one);
172 }
173 llvm_unreachable("Unsupported attribute type");
174 }
175
genIsNonzero(OpBuilder & builder,mlir::Location loc,Value v)176 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
177 Value v) {
178 Type tp = v.getType();
179 Value zero = constantZero(builder, loc, tp);
180 if (tp.isa<FloatType>())
181 return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
182 zero);
183 if (tp.isIntOrIndex())
184 return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
185 zero);
186 if (tp.dyn_cast<ComplexType>())
187 return builder.create<complex::NotEqualOp>(loc, v, zero);
188 llvm_unreachable("Non-numeric type");
189 }
190