1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "CodegenUtils.h" 10 11 #include "mlir/IR/Types.h" 12 #include "mlir/IR/Value.h" 13 14 using namespace mlir; 15 using namespace mlir::sparse_tensor; 16 17 //===----------------------------------------------------------------------===// 18 // ExecutionEngine/SparseTensorUtils helper functions. 19 //===----------------------------------------------------------------------===// 20 21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 22 switch (width) { 23 case 64: 24 return OverheadType::kU64; 25 case 32: 26 return OverheadType::kU32; 27 case 16: 28 return OverheadType::kU16; 29 case 8: 30 return OverheadType::kU8; 31 case 0: 32 return OverheadType::kIndex; 33 } 34 llvm_unreachable("Unsupported overhead bitwidth"); 35 } 36 37 OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { 38 if (tp.isIndex()) 39 return OverheadType::kIndex; 40 if (auto intTp = tp.dyn_cast<IntegerType>()) 41 return overheadTypeEncoding(intTp.getWidth()); 42 llvm_unreachable("Unknown overhead type"); 43 } 44 45 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 46 switch (ot) { 47 case OverheadType::kIndex: 48 return builder.getIndexType(); 49 case OverheadType::kU64: 50 return builder.getIntegerType(64); 51 case OverheadType::kU32: 52 return builder.getIntegerType(32); 53 case OverheadType::kU16: 54 return builder.getIntegerType(16); 55 case OverheadType::kU8: 56 return builder.getIntegerType(8); 57 } 58 llvm_unreachable("Unknown OverheadType"); 59 } 60 61 OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding( 62 const SparseTensorEncodingAttr &enc) { 63 return overheadTypeEncoding(enc.getPointerBitWidth()); 64 } 65 66 OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding( 67 const SparseTensorEncodingAttr &enc) { 68 return overheadTypeEncoding(enc.getIndexBitWidth()); 69 } 70 71 Type mlir::sparse_tensor::getPointerOverheadType( 72 Builder &builder, const SparseTensorEncodingAttr &enc) { 73 return getOverheadType(builder, pointerOverheadTypeEncoding(enc)); 74 } 75 76 Type mlir::sparse_tensor::getIndexOverheadType( 77 Builder &builder, const SparseTensorEncodingAttr &enc) { 78 return getOverheadType(builder, indexOverheadTypeEncoding(enc)); 79 } 80 81 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { 82 switch (ot) { 83 case OverheadType::kIndex: 84 return ""; 85 case OverheadType::kU64: 86 return "64"; 87 case OverheadType::kU32: 88 return "32"; 89 case OverheadType::kU16: 90 return "16"; 91 case OverheadType::kU8: 92 return "8"; 93 } 94 llvm_unreachable("Unknown OverheadType"); 95 } 96 97 StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) { 98 return overheadTypeFunctionSuffix(overheadTypeEncoding(tp)); 99 } 100 101 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 102 if (elemTp.isF64()) 103 return PrimaryType::kF64; 104 if (elemTp.isF32()) 105 return PrimaryType::kF32; 106 if (elemTp.isInteger(64)) 107 return PrimaryType::kI64; 108 if (elemTp.isInteger(32)) 109 return PrimaryType::kI32; 110 if (elemTp.isInteger(16)) 111 return PrimaryType::kI16; 112 if (elemTp.isInteger(8)) 113 return PrimaryType::kI8; 114 if (auto complexTp = elemTp.dyn_cast<ComplexType>()) { 115 auto complexEltTp = complexTp.getElementType(); 116 if (complexEltTp.isF64()) 117 return PrimaryType::kC64; 118 if (complexEltTp.isF32()) 119 return PrimaryType::kC32; 120 } 121 llvm_unreachable("Unknown primary type"); 122 } 123 124 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) { 125 switch (pt) { 126 case PrimaryType::kF64: 127 return "F64"; 128 case PrimaryType::kF32: 129 return "F32"; 130 case PrimaryType::kI64: 131 return "I64"; 132 case PrimaryType::kI32: 133 return "I32"; 134 case PrimaryType::kI16: 135 return "I16"; 136 case PrimaryType::kI8: 137 return "I8"; 138 case PrimaryType::kC64: 139 return "C64"; 140 case PrimaryType::kC32: 141 return "C32"; 142 } 143 llvm_unreachable("Unknown PrimaryType"); 144 } 145 146 StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) { 147 return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp)); 148 } 149 150 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 151 SparseTensorEncodingAttr::DimLevelType dlt) { 152 switch (dlt) { 153 case SparseTensorEncodingAttr::DimLevelType::Dense: 154 return DimLevelType::kDense; 155 case SparseTensorEncodingAttr::DimLevelType::Compressed: 156 return DimLevelType::kCompressed; 157 case SparseTensorEncodingAttr::DimLevelType::Singleton: 158 return DimLevelType::kSingleton; 159 } 160 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 161 } 162 163 //===----------------------------------------------------------------------===// 164 // Misc code generators. 165 //===----------------------------------------------------------------------===// 166 167 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 168 if (tp.isa<FloatType>()) 169 return builder.getFloatAttr(tp, 1.0); 170 if (tp.isa<IndexType>()) 171 return builder.getIndexAttr(1); 172 if (auto intTp = tp.dyn_cast<IntegerType>()) 173 return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 174 if (tp.isa<RankedTensorType, VectorType>()) { 175 auto shapedTp = tp.cast<ShapedType>(); 176 if (auto one = getOneAttr(builder, shapedTp.getElementType())) 177 return DenseElementsAttr::get(shapedTp, one); 178 } 179 llvm_unreachable("Unsupported attribute type"); 180 } 181 182 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 183 Value v) { 184 Type tp = v.getType(); 185 Value zero = constantZero(builder, loc, tp); 186 if (tp.isa<FloatType>()) 187 return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 188 zero); 189 if (tp.isIntOrIndex()) 190 return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 191 zero); 192 if (tp.dyn_cast<ComplexType>()) 193 return builder.create<complex::NotEqualOp>(loc, v, zero); 194 llvm_unreachable("Non-numeric type"); 195 } 196