1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for generating MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
15 
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
19 #include "mlir/IR/Builders.h"
20 
21 namespace mlir {
22 class Location;
23 class Type;
24 class Value;
25 
26 namespace sparse_tensor {
27 
28 //===----------------------------------------------------------------------===//
29 // ExecutionEngine/SparseTensorUtils helper functions.
30 //===----------------------------------------------------------------------===//
31 
32 /// Converts an overhead storage bitwidth to its internal type-encoding.
33 OverheadType overheadTypeEncoding(unsigned width);
34 
35 /// Converts an overhead storage type to its internal type-encoding.
36 OverheadType overheadTypeEncoding(Type tp);
37 
38 /// Converts the internal type-encoding for overhead storage to an mlir::Type.
39 Type getOverheadType(Builder &builder, OverheadType ot);
40 
41 /// Returns the mlir::Type for pointer overhead storage.
42 Type getPointerOverheadType(Builder &builder,
43                             const SparseTensorEncodingAttr &enc);
44 
45 /// Returns the mlir::Type for index overhead storage.
46 Type getIndexOverheadType(Builder &builder,
47                           const SparseTensorEncodingAttr &enc);
48 
49 /// Convert OverheadType to its function-name suffix.
50 StringRef overheadTypeFunctionSuffix(OverheadType ot);
51 
52 /// Converts an overhead storage type to its function-name suffix.
53 StringRef overheadTypeFunctionSuffix(Type overheadTp);
54 
55 /// Converts a primary storage type to its internal type-encoding.
56 PrimaryType primaryTypeEncoding(Type elemTp);
57 
58 /// Convert PrimaryType to its function-name suffix.
59 StringRef primaryTypeFunctionSuffix(PrimaryType pt);
60 
61 /// Converts a primary storage type to its function-name suffix.
62 StringRef primaryTypeFunctionSuffix(Type elemTp);
63 
64 /// Converts the IR's dimension level type to its internal type-encoding.
65 DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
66 
67 //===----------------------------------------------------------------------===//
68 // Misc code generators.
69 //
70 // TODO: both of these should move upstream to their respective classes.
71 // Once RFCs have been created for those changes, list them here.
72 //===----------------------------------------------------------------------===//
73 
74 /// Generates a 1-valued attribute of the given type.  This supports
75 /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
76 /// for unsupported types we raise `llvm_unreachable` rather than
77 /// returning a null attribute.
78 Attribute getOneAttr(Builder &builder, Type tp);
79 
80 /// Generates the comparison `v != 0` where `v` is of numeric type.
81 /// For floating types, we use the "unordered" comparator (i.e., returns
82 /// true if `v` is NaN).
83 Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
84 
85 //===----------------------------------------------------------------------===//
86 // Constant generators.
87 //
88 // All these functions are just wrappers to improve code legibility;
89 // therefore, we mark them as `inline` to avoid introducing any additional
90 // overhead due to the legibility.
91 //
92 // TODO: Ideally these should move upstream, so that we don't
93 // develop a design island.  However, doing so will involve
94 // substantial design work.  For related prior discussion, see
95 // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879>
96 //===----------------------------------------------------------------------===//
97 
98 /// Generates a 0-valued constant of the given type.  In addition to
99 /// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also
100 /// works for `RankedTensorType` and `VectorType` (for which it generates
101 /// a constant `DenseElementsAttr` of zeros).
102 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
103   return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
104 }
105 
106 /// Generates a 1-valued constant of the given type.  This supports all
107 /// the same types as `constantZero`.
108 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
109   return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
110 }
111 
112 /// Generates a constant of `index` type.
113 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
114   return builder.create<arith::ConstantIndexOp>(loc, i);
115 }
116 
117 /// Generates a constant of `i32` type.
118 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
119   return builder.create<arith::ConstantIntOp>(loc, i, 32);
120 }
121 
122 /// Generates a constant of `i16` type.
123 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) {
124   return builder.create<arith::ConstantIntOp>(loc, i, 16);
125 }
126 
127 /// Generates a constant of `i8` type.
128 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) {
129   return builder.create<arith::ConstantIntOp>(loc, i, 8);
130 }
131 
132 /// Generates a constant of `i1` type.
133 inline Value constantI1(OpBuilder &builder, Location loc, bool b) {
134   return builder.create<arith::ConstantIntOp>(loc, b, 1);
135 }
136 
137 /// Generates a constant of the given `Action`.
138 inline Value constantAction(OpBuilder &builder, Location loc, Action action) {
139   return constantI32(builder, loc, static_cast<uint32_t>(action));
140 }
141 
142 /// Generates a constant of the internal type-encoding for overhead storage.
143 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc,
144                                           unsigned width) {
145   return constantI32(builder, loc,
146                      static_cast<uint32_t>(overheadTypeEncoding(width)));
147 }
148 
149 /// Generates a constant of the internal type-encoding for pointer
150 /// overhead storage.
151 inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc,
152                                          const SparseTensorEncodingAttr &enc) {
153   return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth());
154 }
155 
156 /// Generates a constant of the internal type-encoding for index overhead
157 /// storage.
158 inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc,
159                                        const SparseTensorEncodingAttr &enc) {
160   return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth());
161 }
162 
163 /// Generates a constant of the internal type-encoding for primary storage.
164 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
165                                          Type elemTp) {
166   return constantI32(builder, loc,
167                      static_cast<uint32_t>(primaryTypeEncoding(elemTp)));
168 }
169 
170 /// Generates a constant of the internal dimension level type encoding.
171 inline Value
172 constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
173                              SparseTensorEncodingAttr::DimLevelType dlt) {
174   return constantI8(builder, loc,
175                     static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
176 }
177 
178 } // namespace sparse_tensor
179 } // namespace mlir
180 
181 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
182