185b8d03eSwren romano //===- CodegenUtils.h - Utilities for generating MLIR -----------*- C++ -*-===//
285b8d03eSwren romano //
385b8d03eSwren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
485b8d03eSwren romano // See https://llvm.org/LICENSE.txt for license information.
585b8d03eSwren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
685b8d03eSwren romano //
785b8d03eSwren romano //===----------------------------------------------------------------------===//
885b8d03eSwren romano //
985b8d03eSwren romano // This header file defines utilities for generating MLIR.
1085b8d03eSwren romano //
1185b8d03eSwren romano //===----------------------------------------------------------------------===//
1285b8d03eSwren romano 
1385b8d03eSwren romano #ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
1485b8d03eSwren romano #define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
1585b8d03eSwren romano 
1685b8d03eSwren romano #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17*28b6d412SAart Bik #include "mlir/Dialect/Complex/IR/Complex.h"
1885b8d03eSwren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1985b8d03eSwren romano #include "mlir/ExecutionEngine/SparseTensorUtils.h"
2085b8d03eSwren romano #include "mlir/IR/Builders.h"
2185b8d03eSwren romano 
2285b8d03eSwren romano namespace mlir {
2385b8d03eSwren romano class Location;
2485b8d03eSwren romano class Type;
2585b8d03eSwren romano class Value;
2685b8d03eSwren romano 
2785b8d03eSwren romano namespace sparse_tensor {
2885b8d03eSwren romano 
2985b8d03eSwren romano //===----------------------------------------------------------------------===//
3085b8d03eSwren romano // ExecutionEngine/SparseTensorUtils helper functions.
3185b8d03eSwren romano //===----------------------------------------------------------------------===//
3285b8d03eSwren romano 
3385b8d03eSwren romano /// Converts an overhead storage bitwidth to its internal type-encoding.
3485b8d03eSwren romano OverheadType overheadTypeEncoding(unsigned width);
3585b8d03eSwren romano 
36c9489225Swren romano /// Converts an overhead storage type to its internal type-encoding.
37c9489225Swren romano OverheadType overheadTypeEncoding(Type tp);
38c9489225Swren romano 
3985b8d03eSwren romano /// Converts the internal type-encoding for overhead storage to an mlir::Type.
4085b8d03eSwren romano Type getOverheadType(Builder &builder, OverheadType ot);
4185b8d03eSwren romano 
42ebc84664Swren romano /// Returns the OverheadType for pointer overhead storage.
43ebc84664Swren romano OverheadType pointerOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
44ebc84664Swren romano 
45ebc84664Swren romano /// Returns the OverheadType for index overhead storage.
46ebc84664Swren romano OverheadType indexOverheadTypeEncoding(const SparseTensorEncodingAttr &enc);
47ebc84664Swren romano 
4885b8d03eSwren romano /// Returns the mlir::Type for pointer overhead storage.
4985b8d03eSwren romano Type getPointerOverheadType(Builder &builder,
5085b8d03eSwren romano                             const SparseTensorEncodingAttr &enc);
5185b8d03eSwren romano 
5285b8d03eSwren romano /// Returns the mlir::Type for index overhead storage.
5385b8d03eSwren romano Type getIndexOverheadType(Builder &builder,
5485b8d03eSwren romano                           const SparseTensorEncodingAttr &enc);
5585b8d03eSwren romano 
/// Converts an `OverheadType` to its function-name suffix.
57c9489225Swren romano StringRef overheadTypeFunctionSuffix(OverheadType ot);
58c9489225Swren romano 
59c9489225Swren romano /// Converts an overhead storage type to its function-name suffix.
60c9489225Swren romano StringRef overheadTypeFunctionSuffix(Type overheadTp);
61c9489225Swren romano 
6285b8d03eSwren romano /// Converts a primary storage type to its internal type-encoding.
6385b8d03eSwren romano PrimaryType primaryTypeEncoding(Type elemTp);
6485b8d03eSwren romano 
/// Converts a `PrimaryType` to its function-name suffix.
66c9489225Swren romano StringRef primaryTypeFunctionSuffix(PrimaryType pt);
67c9489225Swren romano 
68c9489225Swren romano /// Converts a primary storage type to its function-name suffix.
69c9489225Swren romano StringRef primaryTypeFunctionSuffix(Type elemTp);
70c9489225Swren romano 
7185b8d03eSwren romano /// Converts the IR's dimension level type to its internal type-encoding.
7285b8d03eSwren romano DimLevelType dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt);
7385b8d03eSwren romano 
7485b8d03eSwren romano //===----------------------------------------------------------------------===//
7585b8d03eSwren romano // Misc code generators.
7685b8d03eSwren romano //
7785b8d03eSwren romano // TODO: both of these should move upstream to their respective classes.
7885b8d03eSwren romano // Once RFCs have been created for those changes, list them here.
7985b8d03eSwren romano //===----------------------------------------------------------------------===//
8085b8d03eSwren romano 
8185b8d03eSwren romano /// Generates a 1-valued attribute of the given type.  This supports
8285b8d03eSwren romano /// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
8385b8d03eSwren romano /// for unsupported types we raise `llvm_unreachable` rather than
8485b8d03eSwren romano /// returning a null attribute.
8585b8d03eSwren romano Attribute getOneAttr(Builder &builder, Type tp);
8685b8d03eSwren romano 
8785b8d03eSwren romano /// Generates the comparison `v != 0` where `v` is of numeric type.
8885b8d03eSwren romano /// For floating types, we use the "unordered" comparator (i.e., returns
8985b8d03eSwren romano /// true if `v` is NaN).
9085b8d03eSwren romano Value genIsNonzero(OpBuilder &builder, Location loc, Value v);
9185b8d03eSwren romano 
9285b8d03eSwren romano //===----------------------------------------------------------------------===//
9385b8d03eSwren romano // Constant generators.
9485b8d03eSwren romano //
9585b8d03eSwren romano // All these functions are just wrappers to improve code legibility;
9685b8d03eSwren romano // therefore, we mark them as `inline` to avoid introducing any additional
9785b8d03eSwren romano // overhead due to the legibility.
9885b8d03eSwren romano //
9985b8d03eSwren romano // TODO: Ideally these should move upstream, so that we don't
10085b8d03eSwren romano // develop a design island.  However, doing so will involve
10185b8d03eSwren romano // substantial design work.  For related prior discussion, see
10285b8d03eSwren romano // <https://llvm.discourse.group/t/evolving-builder-apis-based-on-lessons-learned-from-edsc/879>
10385b8d03eSwren romano //===----------------------------------------------------------------------===//
10485b8d03eSwren romano 
10585b8d03eSwren romano /// Generates a 0-valued constant of the given type.  In addition to
106*28b6d412SAart Bik /// the scalar types (`ComplexType`, ``FloatType`, `IndexType`, `IntegerType`),
107*28b6d412SAart Bik /// this also works for `RankedTensorType` and `VectorType` (for which it
108*28b6d412SAart Bik /// generates a constant `DenseElementsAttr` of zeros).
constantZero(OpBuilder & builder,Location loc,Type tp)10985b8d03eSwren romano inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
110*28b6d412SAart Bik   if (auto ctp = tp.dyn_cast<ComplexType>()) {
111*28b6d412SAart Bik     auto zeroe = builder.getZeroAttr(ctp.getElementType());
112*28b6d412SAart Bik     auto zeroa = builder.getArrayAttr({zeroe, zeroe});
113*28b6d412SAart Bik     return builder.create<complex::ConstantOp>(loc, tp, zeroa);
114*28b6d412SAart Bik   }
11585b8d03eSwren romano   return builder.create<arith::ConstantOp>(loc, tp, builder.getZeroAttr(tp));
11685b8d03eSwren romano }
11785b8d03eSwren romano 
11885b8d03eSwren romano /// Generates a 1-valued constant of the given type.  This supports all
11985b8d03eSwren romano /// the same types as `constantZero`.
constantOne(OpBuilder & builder,Location loc,Type tp)12085b8d03eSwren romano inline Value constantOne(OpBuilder &builder, Location loc, Type tp) {
121*28b6d412SAart Bik   if (auto ctp = tp.dyn_cast<ComplexType>()) {
122*28b6d412SAart Bik     auto zeroe = builder.getZeroAttr(ctp.getElementType());
123*28b6d412SAart Bik     auto onee = getOneAttr(builder, ctp.getElementType());
124*28b6d412SAart Bik     auto zeroa = builder.getArrayAttr({onee, zeroe});
125*28b6d412SAart Bik     return builder.create<complex::ConstantOp>(loc, tp, zeroa);
126*28b6d412SAart Bik   }
12785b8d03eSwren romano   return builder.create<arith::ConstantOp>(loc, tp, getOneAttr(builder, tp));
12885b8d03eSwren romano }
12985b8d03eSwren romano 
13085b8d03eSwren romano /// Generates a constant of `index` type.
constantIndex(OpBuilder & builder,Location loc,int64_t i)13185b8d03eSwren romano inline Value constantIndex(OpBuilder &builder, Location loc, int64_t i) {
13285b8d03eSwren romano   return builder.create<arith::ConstantIndexOp>(loc, i);
13385b8d03eSwren romano }
13485b8d03eSwren romano 
13585b8d03eSwren romano /// Generates a constant of `i32` type.
constantI32(OpBuilder & builder,Location loc,int32_t i)13685b8d03eSwren romano inline Value constantI32(OpBuilder &builder, Location loc, int32_t i) {
13785b8d03eSwren romano   return builder.create<arith::ConstantIntOp>(loc, i, 32);
13885b8d03eSwren romano }
13985b8d03eSwren romano 
14085b8d03eSwren romano /// Generates a constant of `i16` type.
constantI16(OpBuilder & builder,Location loc,int16_t i)14185b8d03eSwren romano inline Value constantI16(OpBuilder &builder, Location loc, int16_t i) {
14285b8d03eSwren romano   return builder.create<arith::ConstantIntOp>(loc, i, 16);
14385b8d03eSwren romano }
14485b8d03eSwren romano 
14585b8d03eSwren romano /// Generates a constant of `i8` type.
constantI8(OpBuilder & builder,Location loc,int8_t i)14685b8d03eSwren romano inline Value constantI8(OpBuilder &builder, Location loc, int8_t i) {
14785b8d03eSwren romano   return builder.create<arith::ConstantIntOp>(loc, i, 8);
14885b8d03eSwren romano }
14985b8d03eSwren romano 
15085b8d03eSwren romano /// Generates a constant of `i1` type.
constantI1(OpBuilder & builder,Location loc,bool b)15185b8d03eSwren romano inline Value constantI1(OpBuilder &builder, Location loc, bool b) {
15285b8d03eSwren romano   return builder.create<arith::ConstantIntOp>(loc, b, 1);
15385b8d03eSwren romano }
15485b8d03eSwren romano 
15585b8d03eSwren romano /// Generates a constant of the given `Action`.
constantAction(OpBuilder & builder,Location loc,Action action)15685b8d03eSwren romano inline Value constantAction(OpBuilder &builder, Location loc, Action action) {
15785b8d03eSwren romano   return constantI32(builder, loc, static_cast<uint32_t>(action));
15885b8d03eSwren romano }
15985b8d03eSwren romano 
16085b8d03eSwren romano /// Generates a constant of the internal type-encoding for overhead storage.
constantOverheadTypeEncoding(OpBuilder & builder,Location loc,unsigned width)16185b8d03eSwren romano inline Value constantOverheadTypeEncoding(OpBuilder &builder, Location loc,
16285b8d03eSwren romano                                           unsigned width) {
16385b8d03eSwren romano   return constantI32(builder, loc,
16485b8d03eSwren romano                      static_cast<uint32_t>(overheadTypeEncoding(width)));
16585b8d03eSwren romano }
16685b8d03eSwren romano 
16785b8d03eSwren romano /// Generates a constant of the internal type-encoding for pointer
16885b8d03eSwren romano /// overhead storage.
constantPointerTypeEncoding(OpBuilder & builder,Location loc,const SparseTensorEncodingAttr & enc)16985b8d03eSwren romano inline Value constantPointerTypeEncoding(OpBuilder &builder, Location loc,
17085b8d03eSwren romano                                          const SparseTensorEncodingAttr &enc) {
17185b8d03eSwren romano   return constantOverheadTypeEncoding(builder, loc, enc.getPointerBitWidth());
17285b8d03eSwren romano }
17385b8d03eSwren romano 
17485b8d03eSwren romano /// Generates a constant of the internal type-encoding for index overhead
17585b8d03eSwren romano /// storage.
constantIndexTypeEncoding(OpBuilder & builder,Location loc,const SparseTensorEncodingAttr & enc)17685b8d03eSwren romano inline Value constantIndexTypeEncoding(OpBuilder &builder, Location loc,
17785b8d03eSwren romano                                        const SparseTensorEncodingAttr &enc) {
17885b8d03eSwren romano   return constantOverheadTypeEncoding(builder, loc, enc.getIndexBitWidth());
17985b8d03eSwren romano }
18085b8d03eSwren romano 
18185b8d03eSwren romano /// Generates a constant of the internal type-encoding for primary storage.
constantPrimaryTypeEncoding(OpBuilder & builder,Location loc,Type elemTp)18285b8d03eSwren romano inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
18385b8d03eSwren romano                                          Type elemTp) {
18485b8d03eSwren romano   return constantI32(builder, loc,
18585b8d03eSwren romano                      static_cast<uint32_t>(primaryTypeEncoding(elemTp)));
18685b8d03eSwren romano }
18785b8d03eSwren romano 
18885b8d03eSwren romano /// Generates a constant of the internal dimension level type encoding.
18985b8d03eSwren romano inline Value
constantDimLevelTypeEncoding(OpBuilder & builder,Location loc,SparseTensorEncodingAttr::DimLevelType dlt)19085b8d03eSwren romano constantDimLevelTypeEncoding(OpBuilder &builder, Location loc,
19185b8d03eSwren romano                              SparseTensorEncodingAttr::DimLevelType dlt) {
19285b8d03eSwren romano   return constantI8(builder, loc,
19385b8d03eSwren romano                     static_cast<uint8_t>(dimLevelTypeEncoding(dlt)));
19485b8d03eSwren romano }
19585b8d03eSwren romano 
19685b8d03eSwren romano } // namespace sparse_tensor
19785b8d03eSwren romano } // namespace mlir
19885b8d03eSwren romano 
19985b8d03eSwren romano #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_CODEGENUTILS_H_
200