1*85b8d03eSwren romano //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
2*85b8d03eSwren romano //
3*85b8d03eSwren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*85b8d03eSwren romano // See https://llvm.org/LICENSE.txt for license information.
5*85b8d03eSwren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*85b8d03eSwren romano //
7*85b8d03eSwren romano //===----------------------------------------------------------------------===//
8*85b8d03eSwren romano 
9*85b8d03eSwren romano #include "CodegenUtils.h"
10*85b8d03eSwren romano 
11*85b8d03eSwren romano #include "mlir/IR/Types.h"
12*85b8d03eSwren romano #include "mlir/IR/Value.h"
13*85b8d03eSwren romano 
14*85b8d03eSwren romano using namespace mlir;
15*85b8d03eSwren romano using namespace mlir::sparse_tensor;
16*85b8d03eSwren romano 
17*85b8d03eSwren romano //===----------------------------------------------------------------------===//
18*85b8d03eSwren romano // ExecutionEngine/SparseTensorUtils helper functions.
19*85b8d03eSwren romano //===----------------------------------------------------------------------===//
20*85b8d03eSwren romano 
21*85b8d03eSwren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
22*85b8d03eSwren romano   switch (width) {
23*85b8d03eSwren romano   default:
24*85b8d03eSwren romano     return OverheadType::kU64;
25*85b8d03eSwren romano   case 32:
26*85b8d03eSwren romano     return OverheadType::kU32;
27*85b8d03eSwren romano   case 16:
28*85b8d03eSwren romano     return OverheadType::kU16;
29*85b8d03eSwren romano   case 8:
30*85b8d03eSwren romano     return OverheadType::kU8;
31*85b8d03eSwren romano   }
32*85b8d03eSwren romano }
33*85b8d03eSwren romano 
34*85b8d03eSwren romano Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
35*85b8d03eSwren romano   switch (ot) {
36*85b8d03eSwren romano   case OverheadType::kU64:
37*85b8d03eSwren romano     return builder.getIntegerType(64);
38*85b8d03eSwren romano   case OverheadType::kU32:
39*85b8d03eSwren romano     return builder.getIntegerType(32);
40*85b8d03eSwren romano   case OverheadType::kU16:
41*85b8d03eSwren romano     return builder.getIntegerType(16);
42*85b8d03eSwren romano   case OverheadType::kU8:
43*85b8d03eSwren romano     return builder.getIntegerType(8);
44*85b8d03eSwren romano   }
45*85b8d03eSwren romano   llvm_unreachable("Unknown OverheadType");
46*85b8d03eSwren romano }
47*85b8d03eSwren romano 
48*85b8d03eSwren romano Type mlir::sparse_tensor::getPointerOverheadType(
49*85b8d03eSwren romano     Builder &builder, const SparseTensorEncodingAttr &enc) {
50*85b8d03eSwren romano   // NOTE(wrengr): This workaround will be fixed in D115010.
51*85b8d03eSwren romano   unsigned width = enc.getPointerBitWidth();
52*85b8d03eSwren romano   if (width == 0)
53*85b8d03eSwren romano     return builder.getIndexType();
54*85b8d03eSwren romano   return getOverheadType(builder, overheadTypeEncoding(width));
55*85b8d03eSwren romano }
56*85b8d03eSwren romano 
57*85b8d03eSwren romano Type mlir::sparse_tensor::getIndexOverheadType(
58*85b8d03eSwren romano     Builder &builder, const SparseTensorEncodingAttr &enc) {
59*85b8d03eSwren romano   // NOTE(wrengr): This workaround will be fixed in D115010.
60*85b8d03eSwren romano   unsigned width = enc.getIndexBitWidth();
61*85b8d03eSwren romano   if (width == 0)
62*85b8d03eSwren romano     return builder.getIndexType();
63*85b8d03eSwren romano   return getOverheadType(builder, overheadTypeEncoding(width));
64*85b8d03eSwren romano }
65*85b8d03eSwren romano 
66*85b8d03eSwren romano PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
67*85b8d03eSwren romano   if (elemTp.isF64())
68*85b8d03eSwren romano     return PrimaryType::kF64;
69*85b8d03eSwren romano   if (elemTp.isF32())
70*85b8d03eSwren romano     return PrimaryType::kF32;
71*85b8d03eSwren romano   if (elemTp.isInteger(64))
72*85b8d03eSwren romano     return PrimaryType::kI64;
73*85b8d03eSwren romano   if (elemTp.isInteger(32))
74*85b8d03eSwren romano     return PrimaryType::kI32;
75*85b8d03eSwren romano   if (elemTp.isInteger(16))
76*85b8d03eSwren romano     return PrimaryType::kI16;
77*85b8d03eSwren romano   if (elemTp.isInteger(8))
78*85b8d03eSwren romano     return PrimaryType::kI8;
79*85b8d03eSwren romano   llvm_unreachable("Unknown primary type");
80*85b8d03eSwren romano }
81*85b8d03eSwren romano 
82*85b8d03eSwren romano DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
83*85b8d03eSwren romano     SparseTensorEncodingAttr::DimLevelType dlt) {
84*85b8d03eSwren romano   switch (dlt) {
85*85b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Dense:
86*85b8d03eSwren romano     return DimLevelType::kDense;
87*85b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Compressed:
88*85b8d03eSwren romano     return DimLevelType::kCompressed;
89*85b8d03eSwren romano   case SparseTensorEncodingAttr::DimLevelType::Singleton:
90*85b8d03eSwren romano     return DimLevelType::kSingleton;
91*85b8d03eSwren romano   }
92*85b8d03eSwren romano   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
93*85b8d03eSwren romano }
94*85b8d03eSwren romano 
95*85b8d03eSwren romano //===----------------------------------------------------------------------===//
96*85b8d03eSwren romano // Misc code generators.
97*85b8d03eSwren romano //===----------------------------------------------------------------------===//
98*85b8d03eSwren romano 
99*85b8d03eSwren romano mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
100*85b8d03eSwren romano   if (tp.isa<FloatType>())
101*85b8d03eSwren romano     return builder.getFloatAttr(tp, 1.0);
102*85b8d03eSwren romano   if (tp.isa<IndexType>())
103*85b8d03eSwren romano     return builder.getIndexAttr(1);
104*85b8d03eSwren romano   if (auto intTp = tp.dyn_cast<IntegerType>())
105*85b8d03eSwren romano     return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
106*85b8d03eSwren romano   if (tp.isa<RankedTensorType, VectorType>()) {
107*85b8d03eSwren romano     auto shapedTp = tp.cast<ShapedType>();
108*85b8d03eSwren romano     if (auto one = getOneAttr(builder, shapedTp.getElementType()))
109*85b8d03eSwren romano       return DenseElementsAttr::get(shapedTp, one);
110*85b8d03eSwren romano   }
111*85b8d03eSwren romano   llvm_unreachable("Unsupported attribute type");
112*85b8d03eSwren romano }
113*85b8d03eSwren romano 
114*85b8d03eSwren romano Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
115*85b8d03eSwren romano                                         Value v) {
116*85b8d03eSwren romano   Type tp = v.getType();
117*85b8d03eSwren romano   Value zero = constantZero(builder, loc, tp);
118*85b8d03eSwren romano   if (tp.isa<FloatType>())
119*85b8d03eSwren romano     return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
120*85b8d03eSwren romano                                          zero);
121*85b8d03eSwren romano   if (tp.isIntOrIndex())
122*85b8d03eSwren romano     return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
123*85b8d03eSwren romano                                          zero);
124*85b8d03eSwren romano   llvm_unreachable("Non-numeric type");
125*85b8d03eSwren romano }
126