1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "CodegenUtils.h" 10 11 #include "mlir/IR/Types.h" 12 #include "mlir/IR/Value.h" 13 14 using namespace mlir; 15 using namespace mlir::sparse_tensor; 16 17 //===----------------------------------------------------------------------===// 18 // ExecutionEngine/SparseTensorUtils helper functions. 19 //===----------------------------------------------------------------------===// 20 21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 22 switch (width) { 23 case 64: 24 return OverheadType::kU64; 25 case 32: 26 return OverheadType::kU32; 27 case 16: 28 return OverheadType::kU16; 29 case 8: 30 return OverheadType::kU8; 31 case 0: 32 return OverheadType::kIndex; 33 } 34 llvm_unreachable("Unsupported overhead bitwidth"); 35 } 36 37 OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { 38 if (tp.isIndex()) 39 return OverheadType::kIndex; 40 if (auto intTp = tp.dyn_cast<IntegerType>()) 41 return overheadTypeEncoding(intTp.getWidth()); 42 llvm_unreachable("Unknown overhead type"); 43 } 44 45 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 46 switch (ot) { 47 case OverheadType::kIndex: 48 return builder.getIndexType(); 49 case OverheadType::kU64: 50 return builder.getIntegerType(64); 51 case OverheadType::kU32: 52 return builder.getIntegerType(32); 53 case OverheadType::kU16: 54 return builder.getIntegerType(16); 55 case OverheadType::kU8: 56 return builder.getIntegerType(8); 57 } 58 llvm_unreachable("Unknown OverheadType"); 59 } 60 61 Type mlir::sparse_tensor::getPointerOverheadType( 62 Builder &builder, const SparseTensorEncodingAttr &enc) { 63 return getOverheadType(builder, 64 overheadTypeEncoding(enc.getPointerBitWidth())); 65 } 66 67 Type mlir::sparse_tensor::getIndexOverheadType( 68 Builder &builder, const SparseTensorEncodingAttr &enc) { 69 return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth())); 70 } 71 72 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { 73 switch (ot) { 74 case OverheadType::kIndex: 75 return ""; 76 case OverheadType::kU64: 77 return "64"; 78 case OverheadType::kU32: 79 return "32"; 80 case OverheadType::kU16: 81 return "16"; 82 case OverheadType::kU8: 83 return "8"; 84 } 85 llvm_unreachable("Unknown OverheadType"); 86 } 87 88 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) { 89 return overheadTypeFunctionSuffix(overheadTypeEncoding(tp)); 90 } 91 92 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 93 if (elemTp.isF64()) 94 return PrimaryType::kF64; 95 if (elemTp.isF32()) 96 return PrimaryType::kF32; 97 if (elemTp.isInteger(64)) 98 return PrimaryType::kI64; 99 if (elemTp.isInteger(32)) 100 return PrimaryType::kI32; 101 if (elemTp.isInteger(16)) 102 return PrimaryType::kI16; 103 if (elemTp.isInteger(8)) 104 return PrimaryType::kI8; 105 llvm_unreachable("Unknown primary type"); 106 } 107 108 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) { 109 switch (pt) { 110 case PrimaryType::kF64: 111 return "F64"; 112 case PrimaryType::kF32: 113 return "F32"; 114 case PrimaryType::kI64: 115 return "I64"; 116 case PrimaryType::kI32: 117 return "I32"; 118 case PrimaryType::kI16: 119 return "I16"; 120 case PrimaryType::kI8: 121 return "I8"; 122 } 123 llvm_unreachable("Unknown PrimaryType"); 124 } 125 126 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) { 127 return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp)); 128 } 129 130 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 131 SparseTensorEncodingAttr::DimLevelType dlt) { 132 switch (dlt) { 133 case SparseTensorEncodingAttr::DimLevelType::Dense: 134 return DimLevelType::kDense; 135 case SparseTensorEncodingAttr::DimLevelType::Compressed: 136 return DimLevelType::kCompressed; 137 case SparseTensorEncodingAttr::DimLevelType::Singleton: 138 return DimLevelType::kSingleton; 139 } 140 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 141 } 142 143 //===----------------------------------------------------------------------===// 144 // Misc code generators. 145 //===----------------------------------------------------------------------===// 146 147 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 148 if (tp.isa<FloatType>()) 149 return builder.getFloatAttr(tp, 1.0); 150 if (tp.isa<IndexType>()) 151 return builder.getIndexAttr(1); 152 if (auto intTp = tp.dyn_cast<IntegerType>()) 153 return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 154 if (tp.isa<RankedTensorType, VectorType>()) { 155 auto shapedTp = tp.cast<ShapedType>(); 156 if (auto one = getOneAttr(builder, shapedTp.getElementType())) 157 return DenseElementsAttr::get(shapedTp, one); 158 } 159 llvm_unreachable("Unsupported attribute type"); 160 } 161 162 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 163 Value v) { 164 Type tp = v.getType(); 165 Value zero = constantZero(builder, loc, tp); 166 if (tp.isa<FloatType>()) 167 return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 168 zero); 169 if (tp.isIntOrIndex()) 170 return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 171 zero); 172 llvm_unreachable("Non-numeric type"); 173 } 174