1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "CodegenUtils.h" 10 11 #include "mlir/IR/Types.h" 12 #include "mlir/IR/Value.h" 13 14 using namespace mlir; 15 using namespace mlir::sparse_tensor; 16 17 //===----------------------------------------------------------------------===// 18 // ExecutionEngine/SparseTensorUtils helper functions. 19 //===----------------------------------------------------------------------===// 20 21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 22 switch (width) { 23 case 64: 24 return OverheadType::kU64; 25 case 32: 26 return OverheadType::kU32; 27 case 16: 28 return OverheadType::kU16; 29 case 8: 30 return OverheadType::kU8; 31 case 0: 32 return OverheadType::kIndex; 33 } 34 llvm_unreachable("Unsupported overhead bitwidth"); 35 } 36 37 OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { 38 if (tp.isIndex()) 39 return OverheadType::kIndex; 40 if (auto intTp = tp.dyn_cast<IntegerType>()) 41 return overheadTypeEncoding(intTp.getWidth()); 42 llvm_unreachable("Unknown overhead type"); 43 } 44 45 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 46 switch (ot) { 47 case OverheadType::kIndex: 48 return builder.getIndexType(); 49 case OverheadType::kU64: 50 return builder.getIntegerType(64); 51 case OverheadType::kU32: 52 return builder.getIntegerType(32); 53 case OverheadType::kU16: 54 return builder.getIntegerType(16); 55 case OverheadType::kU8: 56 return builder.getIntegerType(8); 57 } 58 llvm_unreachable("Unknown OverheadType"); 59 } 60 61 OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding( 62 const SparseTensorEncodingAttr &enc) { 63 return overheadTypeEncoding(enc.getPointerBitWidth()); 64 } 65 66 OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding( 67 const SparseTensorEncodingAttr &enc) { 68 return overheadTypeEncoding(enc.getIndexBitWidth()); 69 } 70 71 Type mlir::sparse_tensor::getPointerOverheadType( 72 Builder &builder, const SparseTensorEncodingAttr &enc) { 73 return getOverheadType(builder, pointerOverheadTypeEncoding(enc)); 74 } 75 76 Type mlir::sparse_tensor::getIndexOverheadType( 77 Builder &builder, const SparseTensorEncodingAttr &enc) { 78 return getOverheadType(builder, indexOverheadTypeEncoding(enc)); 79 } 80 81 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { 82 switch (ot) { 83 case OverheadType::kIndex: 84 return ""; 85 case OverheadType::kU64: 86 return "64"; 87 case OverheadType::kU32: 88 return "32"; 89 case OverheadType::kU16: 90 return "16"; 91 case OverheadType::kU8: 92 return "8"; 93 } 94 llvm_unreachable("Unknown OverheadType"); 95 } 96 97 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) { 98 return overheadTypeFunctionSuffix(overheadTypeEncoding(tp)); 99 } 100 101 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 102 if (elemTp.isF64()) 103 return PrimaryType::kF64; 104 if (elemTp.isF32()) 105 return PrimaryType::kF32; 106 if (elemTp.isInteger(64)) 107 return PrimaryType::kI64; 108 if (elemTp.isInteger(32)) 109 return PrimaryType::kI32; 110 if (elemTp.isInteger(16)) 111 return PrimaryType::kI16; 112 if (elemTp.isInteger(8)) 113 return PrimaryType::kI8; 114 llvm_unreachable("Unknown primary type"); 115 } 116 117 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) { 118 switch (pt) { 119 case PrimaryType::kF64: 120 return "F64"; 121 case PrimaryType::kF32: 122 return "F32"; 123 case PrimaryType::kI64: 124 return "I64"; 125 case PrimaryType::kI32: 126 return "I32"; 127 case PrimaryType::kI16: 128 return "I16"; 129 case PrimaryType::kI8: 130 return "I8"; 131 } 132 llvm_unreachable("Unknown PrimaryType"); 133 } 134 135 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) { 136 return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp)); 137 } 138 139 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 140 SparseTensorEncodingAttr::DimLevelType dlt) { 141 switch (dlt) { 142 case SparseTensorEncodingAttr::DimLevelType::Dense: 143 return DimLevelType::kDense; 144 case SparseTensorEncodingAttr::DimLevelType::Compressed: 145 return DimLevelType::kCompressed; 146 case SparseTensorEncodingAttr::DimLevelType::Singleton: 147 return DimLevelType::kSingleton; 148 } 149 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 150 } 151 152 //===----------------------------------------------------------------------===// 153 // Misc code generators. 154 //===----------------------------------------------------------------------===// 155 156 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 157 if (tp.isa<FloatType>()) 158 return builder.getFloatAttr(tp, 1.0); 159 if (tp.isa<IndexType>()) 160 return builder.getIndexAttr(1); 161 if (auto intTp = tp.dyn_cast<IntegerType>()) 162 return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 163 if (tp.isa<RankedTensorType, VectorType>()) { 164 auto shapedTp = tp.cast<ShapedType>(); 165 if (auto one = getOneAttr(builder, shapedTp.getElementType())) 166 return DenseElementsAttr::get(shapedTp, one); 167 } 168 llvm_unreachable("Unsupported attribute type"); 169 } 170 171 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 172 Value v) { 173 Type tp = v.getType(); 174 Value zero = constantZero(builder, loc, tp); 175 if (tp.isa<FloatType>()) 176 return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 177 zero); 178 if (tp.isIntOrIndex()) 179 return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 180 zero); 181 llvm_unreachable("Non-numeric type"); 182 } 183