185b8d03eSwren romano //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
285b8d03eSwren romano //
385b8d03eSwren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
485b8d03eSwren romano // See https://llvm.org/LICENSE.txt for license information.
585b8d03eSwren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
685b8d03eSwren romano //
785b8d03eSwren romano //===----------------------------------------------------------------------===//
885b8d03eSwren romano 
985b8d03eSwren romano #include "CodegenUtils.h"
1085b8d03eSwren romano 
1185b8d03eSwren romano #include "mlir/IR/Types.h"
1285b8d03eSwren romano #include "mlir/IR/Value.h"
1385b8d03eSwren romano 
1485b8d03eSwren romano using namespace mlir;
1585b8d03eSwren romano using namespace mlir::sparse_tensor;
1685b8d03eSwren romano 
1785b8d03eSwren romano //===----------------------------------------------------------------------===//
1885b8d03eSwren romano // ExecutionEngine/SparseTensorUtils helper functions.
1985b8d03eSwren romano //===----------------------------------------------------------------------===//
2085b8d03eSwren romano 
2185b8d03eSwren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
2285b8d03eSwren romano   switch (width) {
23bc04a470Swren romano   case 64:
2485b8d03eSwren romano     return OverheadType::kU64;
2585b8d03eSwren romano   case 32:
2685b8d03eSwren romano     return OverheadType::kU32;
2785b8d03eSwren romano   case 16:
2885b8d03eSwren romano     return OverheadType::kU16;
2985b8d03eSwren romano   case 8:
3085b8d03eSwren romano     return OverheadType::kU8;
31bc04a470Swren romano   case 0:
32bc04a470Swren romano     return OverheadType::kIndex;
3385b8d03eSwren romano   }
34bc04a470Swren romano   llvm_unreachable("Unsupported overhead bitwidth");
3585b8d03eSwren romano }
3685b8d03eSwren romano 
37*c9489225Swren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
38*c9489225Swren romano   if (tp.isIndex())
39*c9489225Swren romano     return OverheadType::kIndex;
40*c9489225Swren romano   if (auto intTp = tp.dyn_cast<IntegerType>())
41*c9489225Swren romano     return overheadTypeEncoding(intTp.getWidth());
42*c9489225Swren romano   llvm_unreachable("Unknown overhead type");
43*c9489225Swren romano }
44*c9489225Swren romano 
4585b8d03eSwren romano Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
4685b8d03eSwren romano   switch (ot) {
47bc04a470Swren romano   case OverheadType::kIndex:
48bc04a470Swren romano     return builder.getIndexType();
4985b8d03eSwren romano   case OverheadType::kU64:
5085b8d03eSwren romano     return builder.getIntegerType(64);
5185b8d03eSwren romano   case OverheadType::kU32:
5285b8d03eSwren romano     return builder.getIntegerType(32);
5385b8d03eSwren romano   case OverheadType::kU16:
5485b8d03eSwren romano     return builder.getIntegerType(16);
5585b8d03eSwren romano   case OverheadType::kU8:
5685b8d03eSwren romano     return builder.getIntegerType(8);
5785b8d03eSwren romano   }
5885b8d03eSwren romano   llvm_unreachable("Unknown OverheadType");
5985b8d03eSwren romano }
6085b8d03eSwren romano 
6185b8d03eSwren romano Type mlir::sparse_tensor::getPointerOverheadType(
6285b8d03eSwren romano     Builder &builder, const SparseTensorEncodingAttr &enc) {
63bc04a470Swren romano   return getOverheadType(builder,
64bc04a470Swren romano                          overheadTypeEncoding(enc.getPointerBitWidth()));
6585b8d03eSwren romano }
6685b8d03eSwren romano 
6785b8d03eSwren romano Type mlir::sparse_tensor::getIndexOverheadType(
6885b8d03eSwren romano     Builder &builder, const SparseTensorEncodingAttr &enc) {
69bc04a470Swren romano   return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
7085b8d03eSwren romano }
7185b8d03eSwren romano 
72*c9489225Swren romano StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
73*c9489225Swren romano   switch (ot) {
74*c9489225Swren romano   case OverheadType::kIndex:
75*c9489225Swren romano     return "";
76*c9489225Swren romano   case OverheadType::kU64:
77*c9489225Swren romano     return "64";
78*c9489225Swren romano   case OverheadType::kU32:
79*c9489225Swren romano     return "32";
80*c9489225Swren romano   case OverheadType::kU16:
81*c9489225Swren romano     return "16";
82*c9489225Swren romano   case OverheadType::kU8:
83*c9489225Swren romano     return "8";
84*c9489225Swren romano   }
85*c9489225Swren romano   llvm_unreachable("Unknown OverheadType");
86*c9489225Swren romano }
87*c9489225Swren romano 
88*c9489225Swren romano StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
89*c9489225Swren romano   return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
90*c9489225Swren romano }
91*c9489225Swren romano 
9285b8d03eSwren romano PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
9385b8d03eSwren romano   if (elemTp.isF64())
9485b8d03eSwren romano     return PrimaryType::kF64;
9585b8d03eSwren romano   if (elemTp.isF32())
9685b8d03eSwren romano     return PrimaryType::kF32;
9785b8d03eSwren romano   if (elemTp.isInteger(64))
9885b8d03eSwren romano     return PrimaryType::kI64;
9985b8d03eSwren romano   if (elemTp.isInteger(32))
10085b8d03eSwren romano     return PrimaryType::kI32;
10185b8d03eSwren romano   if (elemTp.isInteger(16))
10285b8d03eSwren romano     return PrimaryType::kI16;
10385b8d03eSwren romano   if (elemTp.isInteger(8))
10485b8d03eSwren romano     return PrimaryType::kI8;
10585b8d03eSwren romano   llvm_unreachable("Unknown primary type");
10685b8d03eSwren romano }
10785b8d03eSwren romano 
108*c9489225Swren romano StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
109*c9489225Swren romano   switch (pt) {
110*c9489225Swren romano   case PrimaryType::kF64:
111*c9489225Swren romano     return "F64";
112*c9489225Swren romano   case PrimaryType::kF32:
113*c9489225Swren romano     return "F32";
114*c9489225Swren romano   case PrimaryType::kI64:
115*c9489225Swren romano     return "I64";
116*c9489225Swren romano   case PrimaryType::kI32:
117*c9489225Swren romano     return "I32";
118*c9489225Swren romano   case PrimaryType::kI16:
119*c9489225Swren romano     return "I16";
120*c9489225Swren romano   case PrimaryType::kI8:
121*c9489225Swren romano     return "I8";
122*c9489225Swren romano   }
123*c9489225Swren romano   llvm_unreachable("Unknown PrimaryType");
124*c9489225Swren romano }
125*c9489225Swren romano 
126*c9489225Swren romano StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
127*c9489225Swren romano   return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
128*c9489225Swren romano }
129*c9489225Swren romano 
13085b8d03eSwren romano DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
13185b8d03eSwren romano     SparseTensorEncodingAttr::DimLevelType dlt) {
13285b8d03eSwren romano   switch (dlt) {
13385b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Dense:
13485b8d03eSwren romano     return DimLevelType::kDense;
13585b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Compressed:
13685b8d03eSwren romano     return DimLevelType::kCompressed;
13785b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Singleton:
13885b8d03eSwren romano     return DimLevelType::kSingleton;
13985b8d03eSwren romano   }
14085b8d03eSwren romano   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
14185b8d03eSwren romano }
14285b8d03eSwren romano 
14385b8d03eSwren romano //===----------------------------------------------------------------------===//
14485b8d03eSwren romano // Misc code generators.
14585b8d03eSwren romano //===----------------------------------------------------------------------===//
14685b8d03eSwren romano 
14785b8d03eSwren romano mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
14885b8d03eSwren romano   if (tp.isa<FloatType>())
14985b8d03eSwren romano     return builder.getFloatAttr(tp, 1.0);
15085b8d03eSwren romano   if (tp.isa<IndexType>())
15185b8d03eSwren romano     return builder.getIndexAttr(1);
15285b8d03eSwren romano   if (auto intTp = tp.dyn_cast<IntegerType>())
15385b8d03eSwren romano     return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
15485b8d03eSwren romano   if (tp.isa<RankedTensorType, VectorType>()) {
15585b8d03eSwren romano     auto shapedTp = tp.cast<ShapedType>();
15685b8d03eSwren romano     if (auto one = getOneAttr(builder, shapedTp.getElementType()))
15785b8d03eSwren romano       return DenseElementsAttr::get(shapedTp, one);
15885b8d03eSwren romano   }
15985b8d03eSwren romano   llvm_unreachable("Unsupported attribute type");
16085b8d03eSwren romano }
16185b8d03eSwren romano 
16285b8d03eSwren romano Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
16385b8d03eSwren romano                                         Value v) {
16485b8d03eSwren romano   Type tp = v.getType();
16585b8d03eSwren romano   Value zero = constantZero(builder, loc, tp);
16685b8d03eSwren romano   if (tp.isa<FloatType>())
16785b8d03eSwren romano     return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
16885b8d03eSwren romano                                          zero);
16985b8d03eSwren romano   if (tp.isIntOrIndex())
17085b8d03eSwren romano     return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
17185b8d03eSwren romano                                          zero);
17285b8d03eSwren romano   llvm_unreachable("Non-numeric type");
17385b8d03eSwren romano }
174