1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "CodegenUtils.h"
10 
11 #include "mlir/IR/Types.h"
12 #include "mlir/IR/Value.h"
13 
14 using namespace mlir;
15 using namespace mlir::sparse_tensor;
16 
17 //===----------------------------------------------------------------------===//
18 // ExecutionEngine/SparseTensorUtils helper functions.
19 //===----------------------------------------------------------------------===//
20 
21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
22   switch (width) {
23   default:
24     return OverheadType::kU64;
25   case 32:
26     return OverheadType::kU32;
27   case 16:
28     return OverheadType::kU16;
29   case 8:
30     return OverheadType::kU8;
31   }
32 }
33 
34 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
35   switch (ot) {
36   case OverheadType::kU64:
37     return builder.getIntegerType(64);
38   case OverheadType::kU32:
39     return builder.getIntegerType(32);
40   case OverheadType::kU16:
41     return builder.getIntegerType(16);
42   case OverheadType::kU8:
43     return builder.getIntegerType(8);
44   }
45   llvm_unreachable("Unknown OverheadType");
46 }
47 
48 Type mlir::sparse_tensor::getPointerOverheadType(
49     Builder &builder, const SparseTensorEncodingAttr &enc) {
50   // NOTE(wrengr): This workaround will be fixed in D115010.
51   unsigned width = enc.getPointerBitWidth();
52   if (width == 0)
53     return builder.getIndexType();
54   return getOverheadType(builder, overheadTypeEncoding(width));
55 }
56 
57 Type mlir::sparse_tensor::getIndexOverheadType(
58     Builder &builder, const SparseTensorEncodingAttr &enc) {
59   // NOTE(wrengr): This workaround will be fixed in D115010.
60   unsigned width = enc.getIndexBitWidth();
61   if (width == 0)
62     return builder.getIndexType();
63   return getOverheadType(builder, overheadTypeEncoding(width));
64 }
65 
66 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
67   if (elemTp.isF64())
68     return PrimaryType::kF64;
69   if (elemTp.isF32())
70     return PrimaryType::kF32;
71   if (elemTp.isInteger(64))
72     return PrimaryType::kI64;
73   if (elemTp.isInteger(32))
74     return PrimaryType::kI32;
75   if (elemTp.isInteger(16))
76     return PrimaryType::kI16;
77   if (elemTp.isInteger(8))
78     return PrimaryType::kI8;
79   llvm_unreachable("Unknown primary type");
80 }
81 
82 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding(
83     SparseTensorEncodingAttr::DimLevelType dlt) {
84   switch (dlt) {
85   case SparseTensorEncodingAttr::DimLevelType::Dense:
86     return DimLevelType::kDense;
87   case SparseTensorEncodingAttr::DimLevelType::Compressed:
88     return DimLevelType::kCompressed;
89   case SparseTensorEncodingAttr::DimLevelType::Singleton:
90     return DimLevelType::kSingleton;
91   }
92   llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // Misc code generators.
97 //===----------------------------------------------------------------------===//
98 
99 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
100   if (tp.isa<FloatType>())
101     return builder.getFloatAttr(tp, 1.0);
102   if (tp.isa<IndexType>())
103     return builder.getIndexAttr(1);
104   if (auto intTp = tp.dyn_cast<IntegerType>())
105     return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
106   if (tp.isa<RankedTensorType, VectorType>()) {
107     auto shapedTp = tp.cast<ShapedType>();
108     if (auto one = getOneAttr(builder, shapedTp.getElementType()))
109       return DenseElementsAttr::get(shapedTp, one);
110   }
111   llvm_unreachable("Unsupported attribute type");
112 }
113 
114 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
115                                         Value v) {
116   Type tp = v.getType();
117   Value zero = constantZero(builder, loc, tp);
118   if (tp.isa<FloatType>())
119     return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
120                                          zero);
121   if (tp.isIntOrIndex())
122     return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
123                                          zero);
124   llvm_unreachable("Non-numeric type");
125 }
126