1*85b8d03eSwren romano //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 2*85b8d03eSwren romano // 3*85b8d03eSwren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*85b8d03eSwren romano // See https://llvm.org/LICENSE.txt for license information. 5*85b8d03eSwren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*85b8d03eSwren romano // 7*85b8d03eSwren romano //===----------------------------------------------------------------------===// 8*85b8d03eSwren romano 9*85b8d03eSwren romano #include "CodegenUtils.h" 10*85b8d03eSwren romano 11*85b8d03eSwren romano #include "mlir/IR/Types.h" 12*85b8d03eSwren romano #include "mlir/IR/Value.h" 13*85b8d03eSwren romano 14*85b8d03eSwren romano using namespace mlir; 15*85b8d03eSwren romano using namespace mlir::sparse_tensor; 16*85b8d03eSwren romano 17*85b8d03eSwren romano //===----------------------------------------------------------------------===// 18*85b8d03eSwren romano // ExecutionEngine/SparseTensorUtils helper functions. 19*85b8d03eSwren romano //===----------------------------------------------------------------------===// 20*85b8d03eSwren romano 21*85b8d03eSwren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 22*85b8d03eSwren romano switch (width) { 23*85b8d03eSwren romano default: 24*85b8d03eSwren romano return OverheadType::kU64; 25*85b8d03eSwren romano case 32: 26*85b8d03eSwren romano return OverheadType::kU32; 27*85b8d03eSwren romano case 16: 28*85b8d03eSwren romano return OverheadType::kU16; 29*85b8d03eSwren romano case 8: 30*85b8d03eSwren romano return OverheadType::kU8; 31*85b8d03eSwren romano } 32*85b8d03eSwren romano } 33*85b8d03eSwren romano 34*85b8d03eSwren romano Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 35*85b8d03eSwren romano switch (ot) { 36*85b8d03eSwren romano case OverheadType::kU64: 37*85b8d03eSwren romano return builder.getIntegerType(64); 38*85b8d03eSwren romano case OverheadType::kU32: 39*85b8d03eSwren romano return builder.getIntegerType(32); 40*85b8d03eSwren romano case OverheadType::kU16: 41*85b8d03eSwren romano return builder.getIntegerType(16); 42*85b8d03eSwren romano case OverheadType::kU8: 43*85b8d03eSwren romano return builder.getIntegerType(8); 44*85b8d03eSwren romano } 45*85b8d03eSwren romano llvm_unreachable("Unknown OverheadType"); 46*85b8d03eSwren romano } 47*85b8d03eSwren romano 48*85b8d03eSwren romano Type mlir::sparse_tensor::getPointerOverheadType( 49*85b8d03eSwren romano Builder &builder, const SparseTensorEncodingAttr &enc) { 50*85b8d03eSwren romano // NOTE(wrengr): This workaround will be fixed in D115010. 51*85b8d03eSwren romano unsigned width = enc.getPointerBitWidth(); 52*85b8d03eSwren romano if (width == 0) 53*85b8d03eSwren romano return builder.getIndexType(); 54*85b8d03eSwren romano return getOverheadType(builder, overheadTypeEncoding(width)); 55*85b8d03eSwren romano } 56*85b8d03eSwren romano 57*85b8d03eSwren romano Type mlir::sparse_tensor::getIndexOverheadType( 58*85b8d03eSwren romano Builder &builder, const SparseTensorEncodingAttr &enc) { 59*85b8d03eSwren romano // NOTE(wrengr): This workaround will be fixed in D115010. 60*85b8d03eSwren romano unsigned width = enc.getIndexBitWidth(); 61*85b8d03eSwren romano if (width == 0) 62*85b8d03eSwren romano return builder.getIndexType(); 63*85b8d03eSwren romano return getOverheadType(builder, overheadTypeEncoding(width)); 64*85b8d03eSwren romano } 65*85b8d03eSwren romano 66*85b8d03eSwren romano PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 67*85b8d03eSwren romano if (elemTp.isF64()) 68*85b8d03eSwren romano return PrimaryType::kF64; 69*85b8d03eSwren romano if (elemTp.isF32()) 70*85b8d03eSwren romano return PrimaryType::kF32; 71*85b8d03eSwren romano if (elemTp.isInteger(64)) 72*85b8d03eSwren romano return PrimaryType::kI64; 73*85b8d03eSwren romano if (elemTp.isInteger(32)) 74*85b8d03eSwren romano return PrimaryType::kI32; 75*85b8d03eSwren romano if (elemTp.isInteger(16)) 76*85b8d03eSwren romano return PrimaryType::kI16; 77*85b8d03eSwren romano if (elemTp.isInteger(8)) 78*85b8d03eSwren romano return PrimaryType::kI8; 79*85b8d03eSwren romano llvm_unreachable("Unknown primary type"); 80*85b8d03eSwren romano } 81*85b8d03eSwren romano 82*85b8d03eSwren romano DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 83*85b8d03eSwren romano SparseTensorEncodingAttr::DimLevelType dlt) { 84*85b8d03eSwren romano switch (dlt) { 85*85b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Dense: 86*85b8d03eSwren romano return DimLevelType::kDense; 87*85b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Compressed: 88*85b8d03eSwren romano return DimLevelType::kCompressed; 89*85b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Singleton: 90*85b8d03eSwren romano return DimLevelType::kSingleton; 91*85b8d03eSwren romano } 92*85b8d03eSwren romano llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 93*85b8d03eSwren romano } 94*85b8d03eSwren romano 95*85b8d03eSwren romano //===----------------------------------------------------------------------===// 96*85b8d03eSwren romano // Misc code generators. 97*85b8d03eSwren romano //===----------------------------------------------------------------------===// 98*85b8d03eSwren romano 99*85b8d03eSwren romano mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 100*85b8d03eSwren romano if (tp.isa<FloatType>()) 101*85b8d03eSwren romano return builder.getFloatAttr(tp, 1.0); 102*85b8d03eSwren romano if (tp.isa<IndexType>()) 103*85b8d03eSwren romano return builder.getIndexAttr(1); 104*85b8d03eSwren romano if (auto intTp = tp.dyn_cast<IntegerType>()) 105*85b8d03eSwren romano return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 106*85b8d03eSwren romano if (tp.isa<RankedTensorType, VectorType>()) { 107*85b8d03eSwren romano auto shapedTp = tp.cast<ShapedType>(); 108*85b8d03eSwren romano if (auto one = getOneAttr(builder, shapedTp.getElementType())) 109*85b8d03eSwren romano return DenseElementsAttr::get(shapedTp, one); 110*85b8d03eSwren romano } 111*85b8d03eSwren romano llvm_unreachable("Unsupported attribute type"); 112*85b8d03eSwren romano } 113*85b8d03eSwren romano 114*85b8d03eSwren romano Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 115*85b8d03eSwren romano Value v) { 116*85b8d03eSwren romano Type tp = v.getType(); 117*85b8d03eSwren romano Value zero = constantZero(builder, loc, tp); 118*85b8d03eSwren romano if (tp.isa<FloatType>()) 119*85b8d03eSwren romano return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 120*85b8d03eSwren romano zero); 121*85b8d03eSwren romano if (tp.isIntOrIndex()) 122*85b8d03eSwren romano return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 123*85b8d03eSwren romano zero); 124*85b8d03eSwren romano llvm_unreachable("Non-numeric type"); 125*85b8d03eSwren romano } 126