1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for generating MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
15
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/Complex/IR/Complex.h"
18 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
19 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
20 #include "mlir/IR/Builders.h"
21
22 namespace mlir {
23 class Location;
24 class Type;
25 class Value;
26
27 namespace sparse_tensor {
28
29 //===----------------------------------------------------------------------===//
30 // ExecutionEngine/SparseTensorUtils helper functions.
31 //===----------------------------------------------------------------------===//
32
33 /// Converts an overhead storage bitwidth to its internal type-encoding.
34 OverheadType overheadTypeEncoding(unsigned width);
35
36 /// Converts an overhead storage type to its internal type-encoding.
37 OverheadType overheadTypeEncoding(Type tp);
38
39 /// Converts the internal type-encoding for overhead storage to an mlir::Type.
40 Type getOverheadType(Builder &builder, OverheadType ot);
41
42 /// Returns the OverheadType for pointer overhead storage.
43 OverheadType pointerOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
44
45 /// Returns the OverheadType for index overhead storage.
46 OverheadType indexOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
47
48 /// Returns the mlir::Type for pointer overhead storage.
49 Type getPointerOverheadType(Builder &builder,
50 const SparseTensorEncodingAttr &enc);
51
52 /// Returns the mlir::Type for index overhead storage.
53 Type getIndexOverheadType(Builder &builder,
54 const SparseTensorEncodingAttr &enc);
55
56 /// Convert OverheadType to its function-name suffix.
57 StringRef overheadTypeFunctionSuffix(OverheadType ot);
58
59 /// Converts an overhead storage type to its function-name suffix.
60 StringRef overheadTypeFunctionSuffix(Type overheadTp);
61
62 /// Converts a primary storage type to its internal type-encoding.
63 PrimaryType primaryTypeEncoding(Type elemTp);
64
65 /// Convert PrimaryType to its function-name suffix.
66 StringRef primaryTypeFunctionSuffix(PrimaryType pt);
67
68 /// Converts a primary storage type to its function-name suffix.
69 StringRef primaryTypeFunctionSuffix(Type elemTp);
70
71 /// Converts the IR's dimension level type to its internal type-encoding.
72 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
73
74 //===----------------------------------------------------------------------===//
75 // Misc code generators.
76 //
77 // TODO: both of these should move upstream to their respective classes.
78 // Once RFCs have been created for those changes, list them here.
79 //===----------------------------------------------------------------------===//
80
81 /// Generates a 1-valued attribute of the given type. This supports
82 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
83 /// for unsupported types we raise `llvm_unreachable` rather than
84 /// returning a null attribute.
85 Attribute getOneAttr(Builder &builder, Type tp);
86
87 /// Generates the comparison `v != 0` where `v` is of numeric type.
88 /// For floating types, we use the "unordered" comparator (i.e., returns
89 /// true if `v` is NaN).
90 Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
91
92 //===----------------------------------------------------------------------===//
93 // Constant generators.
94 //
// All these functions are just wrappers to improve code legibility;
// therefore, we mark them as `inline` so that the improved legibility
// does not introduce any additional overhead.
98 //
99 // TODO: Ideally these should move upstream, so that we don't
100 // develop a design island. However, doing so will involve
101 // substantial design work. For related prior discussion, see
102 // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879>
103 //===----------------------------------------------------------------------===//
104
105 /// Generates a 0-valued constant of the given type. In addition to
106 /// the scalar types (`ComplexType`, ``FloatType`, `IndexType`, `IntegerType`),
107 /// this also works for `RankedTensorType` and `VectorType` (for which it
108 /// generates a constant `DenseElementsAttr` of zeros).
constantZero(OpBuilder & builder,Location loc,Type tp)109 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
110 if (auto ctp = tp.dyn_cast<ComplexType>()) {
111 auto zeroe = builder.getZeroAttr(ctp.getElementType());
112 auto zeroa = builder.getArrayAttr({zeroe, zeroe});
113 return builder.create<complex::ConstantOp>(loc, tp, zeroa);
114 }
115 return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
116 }
117
118 /// Generates a 1-valued constant of the given type. This supports all
119 /// the same types as `constantZero`.
constantOne(OpBuilder & builder,Location loc,Type tp)120 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
121 if (auto ctp = tp.dyn_cast<ComplexType>()) {
122 auto zeroe = builder.getZeroAttr(ctp.getElementType());
123 auto onee = getOneAttr(builder, ctp.getElementType());
124 auto zeroa = builder.getArrayAttr({onee, zeroe});
125 return builder.create<complex::ConstantOp>(loc, tp, zeroa);
126 }
127 return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
128 }
129
130 /// Generates a constant of `index` type.
constantIndex(OpBuilder & builder,Location loc,int64_t i)131 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
132 return builder.create<arith::ConstantIndexOp>(loc, i);
133 }
134
135 /// Generates a constant of `i32` type.
constantI32(OpBuilder & builder,Location loc,int32_t i)136 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
137 return builder.create<arith::ConstantIntOp>(loc, i, 32);
138 }
139
140 /// Generates a constant of `i16` type.
constantI16(OpBuilder & builder,Location loc,int16_t i)141 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) {
142 return builder.create<arith::ConstantIntOp>(loc, i, 16);
143 }
144
145 /// Generates a constant of `i8` type.
constantI8(OpBuilder & builder,Location loc,int8_t i)146 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) {
147 return builder.create<arith::ConstantIntOp>(loc, i, 8);
148 }
149
150 /// Generates a constant of `i1` type.
constantI1(OpBuilder & builder,Location loc,bool b)151 inline Value constantI1(OpBuilder &builder, Location loc, bool b) {
152 return builder.create<arith::ConstantIntOp>(loc, b, 1);
153 }
154
155 /// Generates a constant of the given `Action`.
constantAction(OpBuilder & builder,Location loc,Action action)156 inline Value constantAction(OpBuilder &builder, Location loc, Action action) {
157 return constantI32(builder, loc, static_cast<uint32_t>(action));
158 }
159
160 /// Generates a constant of the internal type-encoding for overhead storage.
constantOverheadTypeEncoding(OpBuilder & builder,Location loc,unsigned width)161 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc,
162 unsigned width) {
163 return constantI32(builder, loc,
164 static_cast<uint32_t>(overheadTypeEncoding(width)));
165 }
166
167 /// Generates a constant of the internal type-encoding for pointer
168 /// overhead storage.
constantPointerTypeEncoding(OpBuilder & builder,Location loc,const SparseTensorEncodingAttr & enc)169 inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc,
170 const SparseTensorEncodingAttr &enc) {
171 return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth());
172 }
173
174 /// Generates a constant of the internal type-encoding for index overhead
175 /// storage.
constantIndexTypeEncoding(OpBuilder & builder,Location loc,const SparseTensorEncodingAttr & enc)176 inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc,
177 const SparseTensorEncodingAttr &enc) {
178 return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth());
179 }
180
181 /// Generates a constant of the internal type-encoding for primary storage.
constantPrimaryTypeEncoding(OpBuilder & builder,Location loc,Type elemTp)182 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
183 Type elemTp) {
184 return constantI32(builder, loc,
185 static_cast<uint32_t>(primaryTypeEncoding(elemTp)));
186 }
187
188 /// Generates a constant of the internal dimension level type encoding.
189 inline Value
constantDimLevelTypeEncoding(OpBuilder & builder,Location loc,SparseTensorEncodingAttr::DimLevelType dlt)190 constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
191 SparseTensorEncodingAttr::DimLevelType dlt) {
192 return constantI8(builder, loc,
193 static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
194 }
195
196 } // namespace sparse_tensor
197 } // namespace mlir
198
199 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
200