1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "CodegenUtils.h"
10 
11 #include "mlir/IR/Types.h"
12 #include "mlir/IR/Value.h"
13 
14 using namespace mlir;
15 using namespace mlir::sparse_tensor;
16 
17 //===----------------------------------------------------------------------===//
18 // ExecutionEngine/SparseTensorUtils helper functions.
19 //===----------------------------------------------------------------------===//
20 
21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
22   switch (width) {
23   case 64:
24     return OverheadType::kU64;
25   case 32:
26     return OverheadType::kU32;
27   case 16:
28     return OverheadType::kU16;
29   case 8:
30     return OverheadType::kU8;
31   case 0:
32     return OverheadType::kIndex;
33   }
34   llvm_unreachable("Unsupported overhead bitwidth");
35 }
36 
37 OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
38   if (tp.isIndex())
39     return OverheadType::kIndex;
40   if (auto intTp = tp.dyn_cast<IntegerType>())
41     return overheadTypeEncoding(intTp.getWidth());
42   llvm_unreachable("Unknown overhead type");
43 }
44 
45 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
46   switch (ot) {
47   case OverheadType::kIndex:
48     return builder.getIndexType();
49   case OverheadType::kU64:
50     return builder.getIntegerType(64);
51   case OverheadType::kU32:
52     return builder.getIntegerType(32);
53   case OverheadType::kU16:
54     return builder.getIntegerType(16);
55   case OverheadType::kU8:
56     return builder.getIntegerType(8);
57   }
58   llvm_unreachable("Unknown OverheadType");
59 }
60 
61 Type mlir::sparse_tensor::getPointerOverheadType(
62     Builder &builder, const SparseTensorEncodingAttr &enc) {
63   return getOverheadType(builder,
64                          overheadTypeEncoding(enc.getPointerBitWidth()));
65 }
66 
67 Type mlir::sparse_tensor::getIndexOverheadType(
68     Builder &builder, const SparseTensorEncodingAttr &enc) {
69   return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth()));
70 }
71 
72 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
73   switch (ot) {
74   case OverheadType::kIndex:
75     return "";
76   case OverheadType::kU64:
77     return "64";
78   case OverheadType::kU32:
79     return "32";
80   case OverheadType::kU16:
81     return "16";
82   case OverheadType::kU8:
83     return "8";
84   }
85   llvm_unreachable("Unknown OverheadType");
86 }
87 
88 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
89   return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
90 }
91 
92 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
93   if (elemTp.isF64())
94     return PrimaryType::kF64;
95   if (elemTp.isF32())
96     return PrimaryType::kF32;
97   if (elemTp.isInteger(64))
98     return PrimaryType::kI64;
99   if (elemTp.isInteger(32))
100     return PrimaryType::kI32;
101   if (elemTp.isInteger(16))
102     return PrimaryType::kI16;
103   if (elemTp.isInteger(8))
104     return PrimaryType::kI8;
105   llvm_unreachable("Unknown primary type");
106 }
107 
108 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
109   switch (pt) {
110   case PrimaryType::kF64:
111     return "F64";
112   case PrimaryType::kF32:
113     return "F32";
114   case PrimaryType::kI64:
115     return "I64";
116   case PrimaryType::kI32:
117     return "I32";
118   case PrimaryType::kI16:
119     return "I16";
120   case PrimaryType::kI8:
121     return "I8";
122   }
123   llvm_unreachable("Unknown PrimaryType");
124 }
125 
126 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
127   return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
128 }
129 
130 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
131     SparseTensorEncodingAttr::DimLevelType dlt) {
132   switch (dlt) {
133   case SparseTensorEncodingAttr::DimLevelType::Dense:
134     return DimLevelType::kDense;
135   case SparseTensorEncodingAttr::DimLevelType::Compressed:
136     return DimLevelType::kCompressed;
137   case SparseTensorEncodingAttr::DimLevelType::Singleton:
138     return DimLevelType::kSingleton;
139   }
140   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // Misc code generators.
145 //===----------------------------------------------------------------------===//
146 
147 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
148   if (tp.isa<FloatType>())
149     return builder.getFloatAttr(tp, 1.0);
150   if (tp.isa<IndexType>())
151     return builder.getIndexAttr(1);
152   if (auto intTp = tp.dyn_cast<IntegerType>())
153     return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
154   if (tp.isa<RankedTensorType, VectorType>()) {
155     auto shapedTp = tp.cast<ShapedType>();
156     if (auto one = getOneAttr(builder, shapedTp.getElementType()))
157       return DenseElementsAttr::get(shapedTp, one);
158   }
159   llvm_unreachable("Unsupported attribute type");
160 }
161 
162 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
163                                         Value v) {
164   Type tp = v.getType();
165   Value zero = constantZero(builder, loc, tp);
166   if (tp.isa<FloatType>())
167     return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
168                                          zero);
169   if (tp.isIntOrIndex())
170     return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
171                                          zero);
172   llvm_unreachable("Non-numeric type");
173 }
174