1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for generating MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
15 
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
19 #include "mlir/IR/Builders.h"
20 
21 namespace mlir {
22 class Location;
23 class Type;
24 class Value;
25 
26 namespace sparse_tensor {
27 
28 //===----------------------------------------------------------------------===//
29 // ExecutionEngine/SparseTensorUtils helper functions.
30 //===----------------------------------------------------------------------===//
31 
/// Converts an overhead storage bitwidth to its internal type-encoding.
///
/// \param width bitwidth of the overhead (pointer/index) storage.
/// \returns the `OverheadType` enumerator corresponding to \p width.
// NOTE(review): which widths are accepted, and what happens for an
// unsupported width, is decided by the out-of-line definition, which is
// not visible in this header.
OverheadType overheadTypeEncoding(unsigned width);

/// Converts an overhead storage type to its internal type-encoding.
///
/// \param tp the MLIR type used for overhead storage.
/// \returns the `OverheadType` enumerator corresponding to \p tp.
OverheadType overheadTypeEncoding(Type tp);
37 
/// Converts the internal type-encoding for overhead storage to an mlir::Type.
///
/// \param builder builder handed to the conversion; presumably used to
///        construct the resulting type (the definition is out of line).
/// \param ot the overhead type-encoding to convert.
/// \returns the mlir::Type corresponding to \p ot.
Type getOverheadType(Builder &builder, OverheadType ot);
40 
/// Returns the OverheadType for pointer overhead storage.
///
/// \param enc the sparse tensor encoding attribute to query.
// NOTE(review): presumably derived from `enc.getPointerBitWidth()`, the
// accessor the inline `constantPointerTypeEncoding` helper in this header
// uses -- confirm against the out-of-line definition.
OverheadType pointerOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
43 
/// Returns the OverheadType for index overhead storage.
///
/// \param enc the sparse tensor encoding attribute to query.
// NOTE(review): presumably derived from `enc.getIndexBitWidth()`, the
// accessor the inline `constantIndexTypeEncoding` helper in this header
// uses -- confirm against the out-of-line definition.
OverheadType indexOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
46 
/// Returns the mlir::Type for pointer overhead storage.
///
/// \param builder builder handed to the helper; presumably used to
///        construct the resulting type (the definition is out of line).
/// \param enc the sparse tensor encoding attribute to query.
Type getPointerOverheadType(Builder &builder,
                            const SparseTensorEncodingAttr &enc);
50 
/// Returns the mlir::Type for index overhead storage.
///
/// \param builder builder handed to the helper; presumably used to
///        construct the resulting type (the definition is out of line).
/// \param enc the sparse tensor encoding attribute to query.
Type getIndexOverheadType(Builder &builder,
                          const SparseTensorEncodingAttr &enc);
54 
/// Converts an overhead type-encoding to its function-name suffix.
///
/// \param ot the overhead type-encoding.
/// \returns the suffix as a `StringRef`.
// NOTE(review): the lifetime of the returned string is determined by the
// out-of-line definition; presumably it refers to static storage -- confirm.
StringRef overheadTypeFunctionSuffix(OverheadType ot);

/// Converts an overhead storage type to its function-name suffix.
///
/// \param overheadTp the MLIR type used for overhead storage.
/// \returns the suffix as a `StringRef`.
StringRef overheadTypeFunctionSuffix(Type overheadTp);
60 
/// Converts a primary storage type to its internal type-encoding.
///
/// \param elemTp the MLIR element type used for primary storage.
/// \returns the `PrimaryType` enumerator corresponding to \p elemTp.
// NOTE(review): the set of supported element types is decided by the
// out-of-line definition, which is not visible in this header.
PrimaryType primaryTypeEncoding(Type elemTp);
63 
/// Converts a primary type-encoding to its function-name suffix.
///
/// \param pt the primary type-encoding.
/// \returns the suffix as a `StringRef`.
// NOTE(review): the lifetime of the returned string is determined by the
// out-of-line definition; presumably it refers to static storage -- confirm.
StringRef primaryTypeFunctionSuffix(PrimaryType pt);

/// Converts a primary storage type to its function-name suffix.
///
/// \param elemTp the MLIR element type used for primary storage.
/// \returns the suffix as a `StringRef`.
StringRef primaryTypeFunctionSuffix(Type elemTp);
69 
/// Converts the IR's dimension level type to its internal type-encoding.
///
/// \param dlt the dimension level type as declared on
///        `SparseTensorEncodingAttr`.
/// \returns the corresponding internal `DimLevelType` enumerator.  Note that
///          the parameter and return types are two distinct enumerations
///          that happen to share the unqualified name `DimLevelType`.
DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
72 
73 //===----------------------------------------------------------------------===//
74 // Misc code generators.
75 //
76 // TODO: both of these should move upstream to their respective classes.
77 // Once RFCs have been created for those changes, list them here.
78 //===----------------------------------------------------------------------===//
79 
/// Generates a 1-valued attribute of the given type.  This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
/// for unsupported types we raise `llvm_unreachable` rather than
/// returning a null attribute.
///
/// \param builder builder handed to the helper; presumably used to
///        construct the attribute (the definition is out of line).
/// \param tp the type whose "one" value is requested.
/// \returns the attribute representing the value one for \p tp.
Attribute getOneAttr(Builder &builder, Type tp);
85 
/// Generates the comparison `v != 0` where `v` is of numeric type.
/// For floating types, we use the "unordered" comparator (i.e., returns
/// true if `v` is NaN).
///
/// \param builder builder handed to the helper for emitting the comparison.
/// \param loc location to attach to the generated operation(s).
/// \param v the numeric value to compare against zero.
/// \returns the value holding the result of the comparison.
// NOTE(review): the exact operations emitted for each numeric type are
// chosen by the out-of-line definition, which is not visible in this header.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
90 
91 //===----------------------------------------------------------------------===//
92 // Constant generators.
93 //
// All these functions are just thin wrappers that improve code legibility;
// therefore, we mark them as `inline` so that the improved legibility
// does not come at the cost of any additional overhead.
97 //
98 // TODO: Ideally these should move upstream, so that we don't
99 // develop a design island.  However, doing so will involve
100 // substantial design work.  For related prior discussion, see
101 // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879>
102 //===----------------------------------------------------------------------===//
103 
104 /// Generates a 0-valued constant of the given type.  In addition to
105 /// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also
106 /// works for `RankedTensorType` and `VectorType` (for which it generates
107 /// a constant `DenseElementsAttr` of zeros).
108 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
109   return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
110 }
111 
112 /// Generates a 1-valued constant of the given type.  This supports all
113 /// the same types as `constantZero`.
114 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
115   return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
116 }
117 
118 /// Generates a constant of `index` type.
119 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
120   return builder.create<arith::ConstantIndexOp>(loc, i);
121 }
122 
123 /// Generates a constant of `i32` type.
124 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
125   return builder.create<arith::ConstantIntOp>(loc, i, 32);
126 }
127 
128 /// Generates a constant of `i16` type.
129 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) {
130   return builder.create<arith::ConstantIntOp>(loc, i, 16);
131 }
132 
133 /// Generates a constant of `i8` type.
134 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) {
135   return builder.create<arith::ConstantIntOp>(loc, i, 8);
136 }
137 
138 /// Generates a constant of `i1` type.
139 inline Value constantI1(OpBuilder &builder, Location loc, bool b) {
140   return builder.create<arith::ConstantIntOp>(loc, b, 1);
141 }
142 
143 /// Generates a constant of the given `Action`.
144 inline Value constantAction(OpBuilder &builder, Location loc, Action action) {
145   return constantI32(builder, loc, static_cast<uint32_t>(action));
146 }
147 
148 /// Generates a constant of the internal type-encoding for overhead storage.
149 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc,
150                                           unsigned width) {
151   return constantI32(builder, loc,
152                      static_cast<uint32_t>(overheadTypeEncoding(width)));
153 }
154 
155 /// Generates a constant of the internal type-encoding for pointer
156 /// overhead storage.
157 inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc,
158                                          const SparseTensorEncodingAttr &enc) {
159   return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth());
160 }
161 
162 /// Generates a constant of the internal type-encoding for index overhead
163 /// storage.
164 inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc,
165                                        const SparseTensorEncodingAttr &enc) {
166   return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth());
167 }
168 
169 /// Generates a constant of the internal type-encoding for primary storage.
170 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
171                                          Type elemTp) {
172   return constantI32(builder, loc,
173                      static_cast<uint32_t>(primaryTypeEncoding(elemTp)));
174 }
175 
176 /// Generates a constant of the internal dimension level type encoding.
177 inline Value
178 constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
179                              SparseTensorEncodingAttr::DimLevelType dlt) {
180   return constantI8(builder, loc,
181                     static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
182 }
183 
184 } // namespace sparse_tensor
185 } // namespace mlir
186 
187 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
188