1 //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "CodegenUtils.h" 10 11 #include "mlir/IR/Types.h" 12 #include "mlir/IR/Value.h" 13 14 using namespace mlir; 15 using namespace mlir::sparse_tensor; 16 17 //===----------------------------------------------------------------------===// 18 // ExecutionEngine/SparseTensorUtils helper functions. 19 //===----------------------------------------------------------------------===// 20 21 OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 22 switch (width) { 23 default: 24 return OverheadType::kU64; 25 case 32: 26 return OverheadType::kU32; 27 case 16: 28 return OverheadType::kU16; 29 case 8: 30 return OverheadType::kU8; 31 } 32 } 33 34 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 35 switch (ot) { 36 case OverheadType::kU64: 37 return builder.getIntegerType(64); 38 case OverheadType::kU32: 39 return builder.getIntegerType(32); 40 case OverheadType::kU16: 41 return builder.getIntegerType(16); 42 case OverheadType::kU8: 43 return builder.getIntegerType(8); 44 } 45 llvm_unreachable("Unknown OverheadType"); 46 } 47 48 Type mlir::sparse_tensor::getPointerOverheadType( 49 Builder &builder, const SparseTensorEncodingAttr &enc) { 50 // NOTE(wrengr): This workaround will be fixed in D115010. 51 unsigned width = enc.getPointerBitWidth(); 52 if (width == 0) 53 return builder.getIndexType(); 54 return getOverheadType(builder, overheadTypeEncoding(width)); 55 } 56 57 Type mlir::sparse_tensor::getIndexOverheadType( 58 Builder &builder, const SparseTensorEncodingAttr &enc) { 59 // NOTE(wrengr): This workaround will be fixed in D115010. 60 unsigned width = enc.getIndexBitWidth(); 61 if (width == 0) 62 return builder.getIndexType(); 63 return getOverheadType(builder, overheadTypeEncoding(width)); 64 } 65 66 PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 67 if (elemTp.isF64()) 68 return PrimaryType::kF64; 69 if (elemTp.isF32()) 70 return PrimaryType::kF32; 71 if (elemTp.isInteger(64)) 72 return PrimaryType::kI64; 73 if (elemTp.isInteger(32)) 74 return PrimaryType::kI32; 75 if (elemTp.isInteger(16)) 76 return PrimaryType::kI16; 77 if (elemTp.isInteger(8)) 78 return PrimaryType::kI8; 79 llvm_unreachable("Unknown primary type"); 80 } 81 82 DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 83 SparseTensorEncodingAttr::DimLevelType dlt) { 84 switch (dlt) { 85 case SparseTensorEncodingAttr::DimLevelType::Dense: 86 return DimLevelType::kDense; 87 case SparseTensorEncodingAttr::DimLevelType::Compressed: 88 return DimLevelType::kCompressed; 89 case SparseTensorEncodingAttr::DimLevelType::Singleton: 90 return DimLevelType::kSingleton; 91 } 92 llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 93 } 94 95 //===----------------------------------------------------------------------===// 96 // Misc code generators. 97 //===----------------------------------------------------------------------===// 98 99 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 100 if (tp.isa<FloatType>()) 101 return builder.getFloatAttr(tp, 1.0); 102 if (tp.isa<IndexType>()) 103 return builder.getIndexAttr(1); 104 if (auto intTp = tp.dyn_cast<IntegerType>()) 105 return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 106 if (tp.isa<RankedTensorType, VectorType>()) { 107 auto shapedTp = tp.cast<ShapedType>(); 108 if (auto one = getOneAttr(builder, shapedTp.getElementType())) 109 return DenseElementsAttr::get(shapedTp, one); 110 } 111 llvm_unreachable("Unsupported attribute type"); 112 } 113 114 Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 115 Value v) { 116 Type tp = v.getType(); 117 Value zero = constantZero(builder, loc, tp); 118 if (tp.isa<FloatType>()) 119 return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 120 zero); 121 if (tp.isIntOrIndex()) 122 return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 123 zero); 124 llvm_unreachable("Non-numeric type"); 125 } 126