1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "CodegenUtils.h" 10 11 #include "mlir/IR/Types.h" 12 #include "mlir/IR/Value.h" 13 14 using namespace mlir; 15 using namespace mlir::sparse_tensor; 16 17 //===----------------------------------------------------------------------===// 18 // ExecutionEngine/SparseTensorUtils helper functions. 19 //===----------------------------------------------------------------------===// 20 21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 22 switch (width) { 23 case 64: 24 return OverheadType::kU64; 25 case 32: 26 return OverheadType::kU32; 27 case 16: 28 return OverheadType::kU16; 29 case 8: 30 return OverheadType::kU8; 31 case 0: 32 return OverheadType::kIndex; 33 } 34 llvm_unreachable("Unsupported overhead bitwidth"); 35 } 36 37 OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { 38 if (tp.isIndex()) 39 return OverheadType::kIndex; 40 if (auto intTp = tp.dyn_cast<IntegerType>()) 41 return overheadTypeEncoding(intTp.getWidth()); 42 llvm_unreachable("Unknown overhead type"); 43 } 44 45 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 46 switch (ot) { 47 case OverheadType::kIndex: 48 return builder.getIndexType(); 49 case OverheadType::kU64: 50 return builder.getIntegerType(64); 51 case OverheadType::kU32: 52 return builder.getIntegerType(32); 53 case OverheadType::kU16: 54 return builder.getIntegerType(16); 55 case OverheadType::kU8: 56 return builder.getIntegerType(8); 57 } 58 llvm_unreachable("Unknown OverheadType"); 59 } 60 61 OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding( 62 const SparseTensorEncodingAttr &enc) { 63 return overheadTypeEncoding(enc.getPointerBitWidth()); 64 } 65 66 OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding( 67 const SparseTensorEncodingAttr &enc) { 68 return overheadTypeEncoding(enc.getIndexBitWidth()); 69 } 70 71 Type mlir::sparse_tensor::getPointerOverheadType( 72 Builder &builder, const SparseTensorEncodingAttr &enc) { 73 return getOverheadType(builder, pointerOverheadTypeEncoding(enc)); 74 } 75 76 Type mlir::sparse_tensor::getIndexOverheadType( 77 Builder &builder, const SparseTensorEncodingAttr &enc) { 78 return getOverheadType(builder, indexOverheadTypeEncoding(enc)); 79 } 80 81 // TODO: Adjust the naming convention for the constructors of `OverheadType` 82 // and the function-suffix for `kIndex` so we can use the `FOREVERY_O` 83 // x-macro here instead of `FOREVERY_FIXED_O`; to further reduce the 84 // possibility of typo bugs or things getting out of sync. 85 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { 86 switch (ot) { 87 case OverheadType::kIndex: 88 return ""; 89 #define CASE(ONAME, O) \ 90 case OverheadType::kU##ONAME: \ 91 return #ONAME; 92 FOREVERY_FIXED_O(CASE) 93 #undef CASE 94 } 95 llvm_unreachable("Unknown OverheadType"); 96 } 97 98 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) { 99 return overheadTypeFunctionSuffix(overheadTypeEncoding(tp)); 100 } 101 102 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 103 if (elemTp.isF64()) 104 return PrimaryType::kF64; 105 if (elemTp.isF32()) 106 return PrimaryType::kF32; 107 if (elemTp.isInteger(64)) 108 return PrimaryType::kI64; 109 if (elemTp.isInteger(32)) 110 return PrimaryType::kI32; 111 if (elemTp.isInteger(16)) 112 return PrimaryType::kI16; 113 if (elemTp.isInteger(8)) 114 return PrimaryType::kI8; 115 if (auto complexTp = elemTp.dyn_cast<ComplexType>()) { 116 auto complexEltTp = complexTp.getElementType(); 117 if (complexEltTp.isF64()) 118 return PrimaryType::kC64; 119 if (complexEltTp.isF32()) 120 return PrimaryType::kC32; 121 } 122 llvm_unreachable("Unknown primary type"); 123 } 124 125 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) { 126 switch (pt) { 127 #define CASE(VNAME, V) \ 128 case PrimaryType::k##VNAME: \ 129 return #VNAME; 130 FOREVERY_V(CASE) 131 #undef CASE 132 } 133 llvm_unreachable("Unknown PrimaryType"); 134 } 135 136 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) { 137 return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp)); 138 } 139 140 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 141 SparseTensorEncodingAttr::DimLevelType dlt) { 142 switch (dlt) { 143 case SparseTensorEncodingAttr::DimLevelType::Dense: 144 return DimLevelType::kDense; 145 case SparseTensorEncodingAttr::DimLevelType::Compressed: 146 return DimLevelType::kCompressed; 147 case SparseTensorEncodingAttr::DimLevelType::Singleton: 148 return DimLevelType::kSingleton; 149 } 150 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 151 } 152 153 //===----------------------------------------------------------------------===// 154 // Misc code generators. 155 //===----------------------------------------------------------------------===// 156 157 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 158 if (tp.isa<FloatType>()) 159 return builder.getFloatAttr(tp, 1.0); 160 if (tp.isa<IndexType>()) 161 return builder.getIndexAttr(1); 162 if (auto intTp = tp.dyn_cast<IntegerType>()) 163 return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 164 if (tp.isa<RankedTensorType, VectorType>()) { 165 auto shapedTp = tp.cast<ShapedType>(); 166 if (auto one = getOneAttr(builder, shapedTp.getElementType())) 167 return DenseElementsAttr::get(shapedTp, one); 168 } 169 llvm_unreachable("Unsupported attribute type"); 170 } 171 172 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 173 Value v) { 174 Type tp = v.getType(); 175 Value zero = constantZero(builder, loc, tp); 176 if (tp.isa<FloatType>()) 177 return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 178 zero); 179 if (tp.isIntOrIndex()) 180 return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 181 zero); 182 if (tp.dyn_cast<ComplexType>()) 183 return builder.create<complex::NotEqualOp>(loc, v, zero); 184 llvm_unreachable("Non-numeric type"); 185 } 186