1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "CodegenUtils.h" 10 11 #include "mlir/IR/Types.h" 12 #include "mlir/IR/Value.h" 13 14 using namespace mlir; 15 using namespace mlir::sparse_tensor; 16 17 //===----------------------------------------------------------------------===// 18 // ExecutionEngine/SparseTensorUtils helper functions. 19 //===----------------------------------------------------------------------===// 20 21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 22 switch (width) { 23 case 64: 24 return OverheadType::kU64; 25 case 32: 26 return OverheadType::kU32; 27 case 16: 28 return OverheadType::kU16; 29 case 8: 30 return OverheadType::kU8; 31 case 0: 32 return OverheadType::kIndex; 33 } 34 llvm_unreachable("Unsupported overhead bitwidth"); 35 } 36 37 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 38 switch (ot) { 39 case OverheadType::kIndex: 40 return builder.getIndexType(); 41 case OverheadType::kU64: 42 return builder.getIntegerType(64); 43 case OverheadType::kU32: 44 return builder.getIntegerType(32); 45 case OverheadType::kU16: 46 return builder.getIntegerType(16); 47 case OverheadType::kU8: 48 return builder.getIntegerType(8); 49 } 50 llvm_unreachable("Unknown OverheadType"); 51 } 52 53 Type mlir::sparse_tensor::getPointerOverheadType( 54 Builder &builder, const SparseTensorEncodingAttr &enc) { 55 return getOverheadType(builder, 56 overheadTypeEncoding(enc.getPointerBitWidth())); 57 } 58 59 Type mlir::sparse_tensor::getIndexOverheadType( 60 Builder &builder, const SparseTensorEncodingAttr &enc) { 61 return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth())); 62 } 63 64 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 65 if (elemTp.isF64()) 66 return PrimaryType::kF64; 67 if (elemTp.isF32()) 68 return PrimaryType::kF32; 69 if (elemTp.isInteger(64)) 70 return PrimaryType::kI64; 71 if (elemTp.isInteger(32)) 72 return PrimaryType::kI32; 73 if (elemTp.isInteger(16)) 74 return PrimaryType::kI16; 75 if (elemTp.isInteger(8)) 76 return PrimaryType::kI8; 77 llvm_unreachable("Unknown primary type"); 78 } 79 80 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 81 SparseTensorEncodingAttr::DimLevelType dlt) { 82 switch (dlt) { 83 case SparseTensorEncodingAttr::DimLevelType::Dense: 84 return DimLevelType::kDense; 85 case SparseTensorEncodingAttr::DimLevelType::Compressed: 86 return DimLevelType::kCompressed; 87 case SparseTensorEncodingAttr::DimLevelType::Singleton: 88 return DimLevelType::kSingleton; 89 } 90 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // Misc code generators. 95 //===----------------------------------------------------------------------===// 96 97 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 98 if (tp.isa<FloatType>()) 99 return builder.getFloatAttr(tp, 1.0); 100 if (tp.isa<IndexType>()) 101 return builder.getIndexAttr(1); 102 if (auto intTp = tp.dyn_cast<IntegerType>()) 103 return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 104 if (tp.isa<RankedTensorType, VectorType>()) { 105 auto shapedTp = tp.cast<ShapedType>(); 106 if (auto one = getOneAttr(builder, shapedTp.getElementType())) 107 return DenseElementsAttr::get(shapedTp, one); 108 } 109 llvm_unreachable("Unsupported attribute type"); 110 } 111 112 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 113 Value v) { 114 Type tp = v.getType(); 115 Value zero = constantZero(builder, loc, tp); 116 if (tp.isa<FloatType>()) 117 return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 118 zero); 119 if (tp.isIntOrIndex()) 120 return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 121 zero); 122 llvm_unreachable("Non-numeric type"); 123 } 124