185b8d03eSwren romano //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
285b8d03eSwren romano //
385b8d03eSwren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
485b8d03eSwren romano // See https://llvm.org/LICENSE.txt for license information.
585b8d03eSwren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
685b8d03eSwren romano //
785b8d03eSwren romano //===----------------------------------------------------------------------===//
885b8d03eSwren romano 
985b8d03eSwren romano #include "CodegenUtils.h"
1085b8d03eSwren romano 
1185b8d03eSwren romano #include "mlir/IR/Types.h"
1285b8d03eSwren romano #include "mlir/IR/Value.h"
1385b8d03eSwren romano 
1485b8d03eSwren romano using namespace mlir;
1585b8d03eSwren romano using namespace mlir::sparse_tensor;
1685b8d03eSwren romano 
1785b8d03eSwren romano //===----------------------------------------------------------------------===//
1885b8d03eSwren romano // ExecutionEngine/SparseTensorUtils helper functions.
1985b8d03eSwren romano //===----------------------------------------------------------------------===//
2085b8d03eSwren romano 
overheadTypeEncoding(unsigned width)2185b8d03eSwren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
2285b8d03eSwren romano   switch (width) {
23bc04a470Swren romano   case 64:
2485b8d03eSwren romano     return OverheadType::kU64;
2585b8d03eSwren romano   case 32:
2685b8d03eSwren romano     return OverheadType::kU32;
2785b8d03eSwren romano   case 16:
2885b8d03eSwren romano     return OverheadType::kU16;
2985b8d03eSwren romano   case 8:
3085b8d03eSwren romano     return OverheadType::kU8;
31bc04a470Swren romano   case 0:
32bc04a470Swren romano     return OverheadType::kIndex;
3385b8d03eSwren romano   }
34bc04a470Swren romano   llvm_unreachable("Unsupported overhead bitwidth");
3585b8d03eSwren romano }
3685b8d03eSwren romano 
overheadTypeEncoding(Type tp)37c9489225Swren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
38c9489225Swren romano   if (tp.isIndex())
39c9489225Swren romano     return OverheadType::kIndex;
40c9489225Swren romano   if (auto intTp = tp.dyn_cast<IntegerType>())
41c9489225Swren romano     return overheadTypeEncoding(intTp.getWidth());
42c9489225Swren romano   llvm_unreachable("Unknown overhead type");
43c9489225Swren romano }
44c9489225Swren romano 
getOverheadType(Builder & builder,OverheadType ot)4585b8d03eSwren romano Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
4685b8d03eSwren romano   switch (ot) {
47bc04a470Swren romano   case OverheadType::kIndex:
48bc04a470Swren romano     return builder.getIndexType();
4985b8d03eSwren romano   case OverheadType::kU64:
5085b8d03eSwren romano     return builder.getIntegerType(64);
5185b8d03eSwren romano   case OverheadType::kU32:
5285b8d03eSwren romano     return builder.getIntegerType(32);
5385b8d03eSwren romano   case OverheadType::kU16:
5485b8d03eSwren romano     return builder.getIntegerType(16);
5585b8d03eSwren romano   case OverheadType::kU8:
5685b8d03eSwren romano     return builder.getIntegerType(8);
5785b8d03eSwren romano   }
5885b8d03eSwren romano   llvm_unreachable("Unknown OverheadType");
5985b8d03eSwren romano }
6085b8d03eSwren romano 
pointerOverheadTypeEncoding(const SparseTensorEncodingAttr & enc)61ebc84664Swren romano OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding(
62ebc84664Swren romano     const SparseTensorEncodingAttr &enc) {
63ebc84664Swren romano   return overheadTypeEncoding(enc.getPointerBitWidth());
64ebc84664Swren romano }
65ebc84664Swren romano 
indexOverheadTypeEncoding(const SparseTensorEncodingAttr & enc)66ebc84664Swren romano OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding(
67ebc84664Swren romano     const SparseTensorEncodingAttr &enc) {
68ebc84664Swren romano   return overheadTypeEncoding(enc.getIndexBitWidth());
69ebc84664Swren romano }
70ebc84664Swren romano 
getPointerOverheadType(Builder & builder,const SparseTensorEncodingAttr & enc)7185b8d03eSwren romano Type mlir::sparse_tensor::getPointerOverheadType(
7285b8d03eSwren romano     Builder &builder, const SparseTensorEncodingAttr &enc) {
73ebc84664Swren romano   return getOverheadType(builder, pointerOverheadTypeEncoding(enc));
7485b8d03eSwren romano }
7585b8d03eSwren romano 
getIndexOverheadType(Builder & builder,const SparseTensorEncodingAttr & enc)7685b8d03eSwren romano Type mlir::sparse_tensor::getIndexOverheadType(
7785b8d03eSwren romano     Builder &builder, const SparseTensorEncodingAttr &enc) {
78ebc84664Swren romano   return getOverheadType(builder, indexOverheadTypeEncoding(enc));
7985b8d03eSwren romano }
8085b8d03eSwren romano 
81b364c766Swren romano // TODO: Adjust the naming convention for the constructors of
82b364c766Swren romano // `OverheadType` so we can use the `FOREVERY_O` x-macro here instead
83b364c766Swren romano // of `FOREVERY_FIXED_O`; to further reduce the possibility of typo bugs
84b364c766Swren romano // or things getting out of sync.
overheadTypeFunctionSuffix(OverheadType ot)85c9489225Swren romano StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
86c9489225Swren romano   switch (ot) {
87c9489225Swren romano   case OverheadType::kIndex:
88b364c766Swren romano     return "0";
8998e142cdSwren romano #define CASE(ONAME, O)                                                         \
9098e142cdSwren romano   case OverheadType::kU##ONAME:                                                \
9198e142cdSwren romano     return #ONAME;
9298e142cdSwren romano     FOREVERY_FIXED_O(CASE)
9398e142cdSwren romano #undef CASE
94c9489225Swren romano   }
95c9489225Swren romano   llvm_unreachable("Unknown OverheadType");
96c9489225Swren romano }
97c9489225Swren romano 
overheadTypeFunctionSuffix(Type tp)98c9489225Swren romano StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
99c9489225Swren romano   return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
100c9489225Swren romano }
101c9489225Swren romano 
primaryTypeEncoding(Type elemTp)10285b8d03eSwren romano PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
10385b8d03eSwren romano   if (elemTp.isF64())
10485b8d03eSwren romano     return PrimaryType::kF64;
10585b8d03eSwren romano   if (elemTp.isF32())
10685b8d03eSwren romano     return PrimaryType::kF32;
107*ea8ed5cbSbixia1   if (elemTp.isF16())
108*ea8ed5cbSbixia1     return PrimaryType::kF16;
109*ea8ed5cbSbixia1   if (elemTp.isBF16())
110*ea8ed5cbSbixia1     return PrimaryType::kBF16;
11185b8d03eSwren romano   if (elemTp.isInteger(64))
11285b8d03eSwren romano     return PrimaryType::kI64;
11385b8d03eSwren romano   if (elemTp.isInteger(32))
11485b8d03eSwren romano     return PrimaryType::kI32;
11585b8d03eSwren romano   if (elemTp.isInteger(16))
11685b8d03eSwren romano     return PrimaryType::kI16;
11785b8d03eSwren romano   if (elemTp.isInteger(8))
11885b8d03eSwren romano     return PrimaryType::kI8;
119736c1b66SAart Bik   if (auto complexTp = elemTp.dyn_cast<ComplexType>()) {
120736c1b66SAart Bik     auto complexEltTp = complexTp.getElementType();
121736c1b66SAart Bik     if (complexEltTp.isF64())
122736c1b66SAart Bik       return PrimaryType::kC64;
123736c1b66SAart Bik     if (complexEltTp.isF32())
124736c1b66SAart Bik       return PrimaryType::kC32;
125736c1b66SAart Bik   }
12685b8d03eSwren romano   llvm_unreachable("Unknown primary type");
12785b8d03eSwren romano }
12885b8d03eSwren romano 
primaryTypeFunctionSuffix(PrimaryType pt)129c9489225Swren romano StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
130c9489225Swren romano   switch (pt) {
13198e142cdSwren romano #define CASE(VNAME, V)                                                         \
13298e142cdSwren romano   case PrimaryType::k##VNAME:                                                  \
13398e142cdSwren romano     return #VNAME;
13498e142cdSwren romano     FOREVERY_V(CASE)
13598e142cdSwren romano #undef CASE
136c9489225Swren romano   }
137c9489225Swren romano   llvm_unreachable("Unknown PrimaryType");
138c9489225Swren romano }
139c9489225Swren romano 
primaryTypeFunctionSuffix(Type elemTp)140c9489225Swren romano StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
141c9489225Swren romano   return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
142c9489225Swren romano }
143c9489225Swren romano 
dimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt)14485b8d03eSwren romano DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
14585b8d03eSwren romano     SparseTensorEncodingAttr::DimLevelType dlt) {
14685b8d03eSwren romano   switch (dlt) {
14785b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Dense:
14885b8d03eSwren romano     return DimLevelType::kDense;
14985b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Compressed:
15085b8d03eSwren romano     return DimLevelType::kCompressed;
15185b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Singleton:
15285b8d03eSwren romano     return DimLevelType::kSingleton;
15385b8d03eSwren romano   }
15485b8d03eSwren romano   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
15585b8d03eSwren romano }
15685b8d03eSwren romano 
15785b8d03eSwren romano //===----------------------------------------------------------------------===//
15885b8d03eSwren romano // Misc code generators.
15985b8d03eSwren romano //===----------------------------------------------------------------------===//
16085b8d03eSwren romano 
getOneAttr(Builder & builder,Type tp)16185b8d03eSwren romano mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
16285b8d03eSwren romano   if (tp.isa<FloatType>())
16385b8d03eSwren romano     return builder.getFloatAttr(tp, 1.0);
16485b8d03eSwren romano   if (tp.isa<IndexType>())
16585b8d03eSwren romano     return builder.getIndexAttr(1);
16685b8d03eSwren romano   if (auto intTp = tp.dyn_cast<IntegerType>())
16785b8d03eSwren romano     return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
16885b8d03eSwren romano   if (tp.isa<RankedTensorType, VectorType>()) {
16985b8d03eSwren romano     auto shapedTp = tp.cast<ShapedType>();
17085b8d03eSwren romano     if (auto one = getOneAttr(builder, shapedTp.getElementType()))
17185b8d03eSwren romano       return DenseElementsAttr::get(shapedTp, one);
17285b8d03eSwren romano   }
17385b8d03eSwren romano   llvm_unreachable("Unsupported attribute type");
17485b8d03eSwren romano }
17585b8d03eSwren romano 
genIsNonzero(OpBuilder & builder,mlir::Location loc,Value v)17685b8d03eSwren romano Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
17785b8d03eSwren romano                                         Value v) {
17885b8d03eSwren romano   Type tp = v.getType();
17985b8d03eSwren romano   Value zero = constantZero(builder, loc, tp);
18085b8d03eSwren romano   if (tp.isa<FloatType>())
18185b8d03eSwren romano     return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
18285b8d03eSwren romano                                          zero);
18385b8d03eSwren romano   if (tp.isIntOrIndex())
18485b8d03eSwren romano     return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
18585b8d03eSwren romano                                          zero);
18628b6d412SAart Bik   if (tp.dyn_cast<ComplexType>())
18728b6d412SAart Bik     return builder.create<complex::NotEqualOp>(loc, v, zero);
18885b8d03eSwren romano   llvm_unreachable("Non-numeric type");
18985b8d03eSwren romano }
190