1 //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines utilities for generating MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
14 #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
15 
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18 #include "mlir/ExecutionEngine/SparseTensorUtils.h"
19 #include "mlir/IR/Builders.h"
20 
21 namespace mlir {
22 class Location;
23 class Type;
24 class Value;
25 
26 namespace sparse_tensor {
27 
28 //===----------------------------------------------------------------------===//
29 // ExecutionEngine/SparseTensorUtils helper functions.
30 //===----------------------------------------------------------------------===//
31 
/// Converts an overhead storage bitwidth (such as the pointer or index bit
/// width of a `SparseTensorEncodingAttr`) to its internal `OverheadType`
/// type-encoding.
OverheadType overheadTypeEncoding(unsigned width);

/// Converts the internal type-encoding `ot` for overhead storage to the
/// corresponding `mlir::Type`.
Type getOverheadType(Builder &builder, OverheadType ot);

/// Returns the `mlir::Type` used for pointer overhead storage of a tensor
/// whose sparse encoding is `enc`.
Type getPointerOverheadType(Builder &builder,
                            const SparseTensorEncodingAttr &enc);

/// Returns the `mlir::Type` used for index overhead storage of a tensor
/// whose sparse encoding is `enc`.
Type getIndexOverheadType(Builder &builder,
                          const SparseTensorEncodingAttr &enc);

/// Converts a primary storage (i.e. element) type `elemTp` to its internal
/// `PrimaryType` type-encoding.
PrimaryType primaryTypeEncoding(Type elemTp);

/// Converts the IR's dimension level type `dlt` to its internal
/// type-encoding.
/// NOTE: two different enums share the name `DimLevelType` here: the
/// parameter is the attribute's nested
/// `SparseTensorEncodingAttr::DimLevelType`, while the return type is the
/// unqualified `DimLevelType` used by the internal encoding.
DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
51 
52 //===----------------------------------------------------------------------===//
53 // Misc code generators.
54 //
// TODO: both functions in this section should move upstream to their
// respective classes.  Once RFCs have been created for those changes,
// list them here.
57 //===----------------------------------------------------------------------===//
58 
/// Generates a 1-valued attribute of type `tp`.  This supports all the
/// same types as `Builder::getZeroAttr`; however, unlike `getZeroAttr`,
/// unsupported types hit `llvm_unreachable` rather than yielding a null
/// attribute.
Attribute getOneAttr(Builder &builder, Type tp);

/// Generates, at `loc`, the comparison `v != 0` where `v` is of numeric
/// type, and returns its result.  For floating types the "unordered"
/// comparator is used, so the result is true when `v` is NaN.
Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
69 
70 //===----------------------------------------------------------------------===//
71 // Constant generators.
72 //
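// All these functions are thin wrappers that improve code legibility.
// Because they are defined in this header they are marked `inline`, which
// keeps the definitions ODR-safe across translation units and lets the
// compiler fold them away, so the added legibility should cost nothing at
// runtime.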
76 //
77 // TODO: Ideally these should move upstream, so that we don't
78 // develop a design island.  However, doing so will involve
79 // substantial design work.  For related prior discussion, see
80 // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879>
81 //===----------------------------------------------------------------------===//
82 
83 /// Generates a 0-valued constant of the given type.  In addition to
84 /// the scalar types (`FloatType`, `IndexType`, `IntegerType`), this also
85 /// works for `RankedTensorType` and `VectorType` (for which it generates
86 /// a constant `DenseElementsAttr` of zeros).
87 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
88   return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
89 }
90 
91 /// Generates a 1-valued constant of the given type.  This supports all
92 /// the same types as `constantZero`.
93 inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
94   return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
95 }
96 
97 /// Generates a constant of `index` type.
98 inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
99   return builder.create<arith::ConstantIndexOp>(loc, i);
100 }
101 
102 /// Generates a constant of `i32` type.
103 inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
104   return builder.create<arith::ConstantIntOp>(loc, i, 32);
105 }
106 
107 /// Generates a constant of `i16` type.
108 inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) {
109   return builder.create<arith::ConstantIntOp>(loc, i, 16);
110 }
111 
112 /// Generates a constant of `i8` type.
113 inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) {
114   return builder.create<arith::ConstantIntOp>(loc, i, 8);
115 }
116 
117 /// Generates a constant of `i1` type.
118 inline Value constantI1(OpBuilder &builder, Location loc, bool b) {
119   return builder.create<arith::ConstantIntOp>(loc, b, 1);
120 }
121 
122 /// Generates a constant of the given `Action`.
123 inline Value constantAction(OpBuilder &builder, Location loc, Action action) {
124   return constantI32(builder, loc, static_cast<uint32_t>(action));
125 }
126 
127 /// Generates a constant of the internal type-encoding for overhead storage.
128 inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc,
129                                           unsigned width) {
130   return constantI32(builder, loc,
131                      static_cast<uint32_t>(overheadTypeEncoding(width)));
132 }
133 
134 /// Generates a constant of the internal type-encoding for pointer
135 /// overhead storage.
136 inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc,
137                                          const SparseTensorEncodingAttr &enc) {
138   return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth());
139 }
140 
141 /// Generates a constant of the internal type-encoding for index overhead
142 /// storage.
143 inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc,
144                                        const SparseTensorEncodingAttr &enc) {
145   return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth());
146 }
147 
148 /// Generates a constant of the internal type-encoding for primary storage.
149 inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
150                                          Type elemTp) {
151   return constantI32(builder, loc,
152                      static_cast<uint32_t>(primaryTypeEncoding(elemTp)));
153 }
154 
155 /// Generates a constant of the internal dimension level type encoding.
156 inline Value
157 constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
158                              SparseTensorEncodingAttr::DimLevelType dlt) {
159   return constantI8(builder, loc,
160                     static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
161 }
162 
163 } // namespace sparse_tensor
164 } // namespace mlir
165 
166 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
167