185b8d03eSwren romano //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 285b8d03eSwren romano // 385b8d03eSwren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 485b8d03eSwren romano // See https://llvm.org/LICENSE.txt for license information. 585b8d03eSwren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 685b8d03eSwren romano // 785b8d03eSwren romano //===----------------------------------------------------------------------===// 885b8d03eSwren romano 985b8d03eSwren romano #include "CodegenUtils.h" 1085b8d03eSwren romano 1185b8d03eSwren romano #include "mlir/IR/Types.h" 1285b8d03eSwren romano #include "mlir/IR/Value.h" 1385b8d03eSwren romano 1485b8d03eSwren romano using namespace mlir; 1585b8d03eSwren romano using namespace mlir::sparse_tensor; 1685b8d03eSwren romano 1785b8d03eSwren romano //===----------------------------------------------------------------------===// 1885b8d03eSwren romano // ExecutionEngine/SparseTensorUtils helper functions. 1985b8d03eSwren romano //===----------------------------------------------------------------------===// 2085b8d03eSwren romano 2185b8d03eSwren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 2285b8d03eSwren romano switch (width) { 23bc04a470Swren romano case 64: 2485b8d03eSwren romano return OverheadType::kU64; 2585b8d03eSwren romano case 32: 2685b8d03eSwren romano return OverheadType::kU32; 2785b8d03eSwren romano case 16: 2885b8d03eSwren romano return OverheadType::kU16; 2985b8d03eSwren romano case 8: 3085b8d03eSwren romano return OverheadType::kU8; 31bc04a470Swren romano case 0: 32bc04a470Swren romano return OverheadType::kIndex; 3385b8d03eSwren romano } 34bc04a470Swren romano llvm_unreachable("Unsupported overhead bitwidth"); 3585b8d03eSwren romano } 3685b8d03eSwren romano 37c9489225Swren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { 38c9489225Swren romano if (tp.isIndex()) 39c9489225Swren romano return OverheadType::kIndex; 40c9489225Swren romano if (auto intTp = tp.dyn_cast<IntegerType>()) 41c9489225Swren romano return overheadTypeEncoding(intTp.getWidth()); 42c9489225Swren romano llvm_unreachable("Unknown overhead type"); 43c9489225Swren romano } 44c9489225Swren romano 4585b8d03eSwren romano Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 4685b8d03eSwren romano switch (ot) { 47bc04a470Swren romano case OverheadType::kIndex: 48bc04a470Swren romano return builder.getIndexType(); 4985b8d03eSwren romano case OverheadType::kU64: 5085b8d03eSwren romano return builder.getIntegerType(64); 5185b8d03eSwren romano case OverheadType::kU32: 5285b8d03eSwren romano return builder.getIntegerType(32); 5385b8d03eSwren romano case OverheadType::kU16: 5485b8d03eSwren romano return builder.getIntegerType(16); 5585b8d03eSwren romano case OverheadType::kU8: 5685b8d03eSwren romano return builder.getIntegerType(8); 5785b8d03eSwren romano } 5885b8d03eSwren romano llvm_unreachable("Unknown OverheadType"); 5985b8d03eSwren romano } 6085b8d03eSwren romano 61ebc84664Swren romano OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding( 62ebc84664Swren romano const SparseTensorEncodingAttr &enc) { 63ebc84664Swren romano return overheadTypeEncoding(enc.getPointerBitWidth()); 64ebc84664Swren romano } 65ebc84664Swren romano 66ebc84664Swren romano OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding( 67ebc84664Swren romano const SparseTensorEncodingAttr &enc) { 68ebc84664Swren romano return overheadTypeEncoding(enc.getIndexBitWidth()); 69ebc84664Swren romano } 70ebc84664Swren romano 7185b8d03eSwren romano Type mlir::sparse_tensor::getPointerOverheadType( 7285b8d03eSwren romano Builder &builder, const SparseTensorEncodingAttr &enc) { 73ebc84664Swren romano return getOverheadType(builder, pointerOverheadTypeEncoding(enc)); 7485b8d03eSwren romano } 7585b8d03eSwren romano 7685b8d03eSwren romano Type mlir::sparse_tensor::getIndexOverheadType( 7785b8d03eSwren romano Builder &builder, const SparseTensorEncodingAttr &enc) { 78ebc84664Swren romano return getOverheadType(builder, indexOverheadTypeEncoding(enc)); 7985b8d03eSwren romano } 8085b8d03eSwren romano 81c9489225Swren romano StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { 82c9489225Swren romano switch (ot) { 83c9489225Swren romano case OverheadType::kIndex: 84c9489225Swren romano return ""; 85c9489225Swren romano case OverheadType::kU64: 86c9489225Swren romano return "64"; 87c9489225Swren romano case OverheadType::kU32: 88c9489225Swren romano return "32"; 89c9489225Swren romano case OverheadType::kU16: 90c9489225Swren romano return "16"; 91c9489225Swren romano case OverheadType::kU8: 92c9489225Swren romano return "8"; 93c9489225Swren romano } 94c9489225Swren romano llvm_unreachable("Unknown OverheadType"); 95c9489225Swren romano } 96c9489225Swren romano 97c9489225Swren romano StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) { 98c9489225Swren romano return overheadTypeFunctionSuffix(overheadTypeEncoding(tp)); 99c9489225Swren romano } 100c9489225Swren romano 10185b8d03eSwren romano PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 10285b8d03eSwren romano if (elemTp.isF64()) 10385b8d03eSwren romano return PrimaryType::kF64; 10485b8d03eSwren romano if (elemTp.isF32()) 10585b8d03eSwren romano return PrimaryType::kF32; 10685b8d03eSwren romano if (elemTp.isInteger(64)) 10785b8d03eSwren romano return PrimaryType::kI64; 10885b8d03eSwren romano if (elemTp.isInteger(32)) 10985b8d03eSwren romano return PrimaryType::kI32; 11085b8d03eSwren romano if (elemTp.isInteger(16)) 11185b8d03eSwren romano return PrimaryType::kI16; 11285b8d03eSwren romano if (elemTp.isInteger(8)) 11385b8d03eSwren romano return PrimaryType::kI8; 114736c1b66SAart Bik if (auto complexTp = elemTp.dyn_cast<ComplexType>()) { 115736c1b66SAart Bik auto complexEltTp = complexTp.getElementType(); 116736c1b66SAart Bik if (complexEltTp.isF64()) 117736c1b66SAart Bik return PrimaryType::kC64; 118736c1b66SAart Bik if (complexEltTp.isF32()) 119736c1b66SAart Bik return PrimaryType::kC32; 120736c1b66SAart Bik } 12185b8d03eSwren romano llvm_unreachable("Unknown primary type"); 12285b8d03eSwren romano } 12385b8d03eSwren romano 124c9489225Swren romano StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) { 125c9489225Swren romano switch (pt) { 126c9489225Swren romano case PrimaryType::kF64: 127c9489225Swren romano return "F64"; 128c9489225Swren romano case PrimaryType::kF32: 129c9489225Swren romano return "F32"; 130c9489225Swren romano case PrimaryType::kI64: 131c9489225Swren romano return "I64"; 132c9489225Swren romano case PrimaryType::kI32: 133c9489225Swren romano return "I32"; 134c9489225Swren romano case PrimaryType::kI16: 135c9489225Swren romano return "I16"; 136c9489225Swren romano case PrimaryType::kI8: 137c9489225Swren romano return "I8"; 138736c1b66SAart Bik case PrimaryType::kC64: 139736c1b66SAart Bik return "C64"; 140736c1b66SAart Bik case PrimaryType::kC32: 141736c1b66SAart Bik return "C32"; 142c9489225Swren romano } 143c9489225Swren romano llvm_unreachable("Unknown PrimaryType"); 144c9489225Swren romano } 145c9489225Swren romano 146c9489225Swren romano StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) { 147c9489225Swren romano return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp)); 148c9489225Swren romano } 149c9489225Swren romano 15085b8d03eSwren romano DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 15185b8d03eSwren romano SparseTensorEncodingAttr::DimLevelType dlt) { 15285b8d03eSwren romano switch (dlt) { 15385b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Dense: 15485b8d03eSwren romano return DimLevelType::kDense; 15585b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Compressed: 15685b8d03eSwren romano return DimLevelType::kCompressed; 15785b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Singleton: 15885b8d03eSwren romano return DimLevelType::kSingleton; 15985b8d03eSwren romano } 16085b8d03eSwren romano llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 16185b8d03eSwren romano } 16285b8d03eSwren romano 16385b8d03eSwren romano //===----------------------------------------------------------------------===// 16485b8d03eSwren romano // Misc code generators. 16585b8d03eSwren romano //===----------------------------------------------------------------------===// 16685b8d03eSwren romano 16785b8d03eSwren romano mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 16885b8d03eSwren romano if (tp.isa<FloatType>()) 16985b8d03eSwren romano return builder.getFloatAttr(tp, 1.0); 17085b8d03eSwren romano if (tp.isa<IndexType>()) 17185b8d03eSwren romano return builder.getIndexAttr(1); 17285b8d03eSwren romano if (auto intTp = tp.dyn_cast<IntegerType>()) 17385b8d03eSwren romano return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 17485b8d03eSwren romano if (tp.isa<RankedTensorType, VectorType>()) { 17585b8d03eSwren romano auto shapedTp = tp.cast<ShapedType>(); 17685b8d03eSwren romano if (auto one = getOneAttr(builder, shapedTp.getElementType())) 17785b8d03eSwren romano return DenseElementsAttr::get(shapedTp, one); 17885b8d03eSwren romano } 17985b8d03eSwren romano llvm_unreachable("Unsupported attribute type"); 18085b8d03eSwren romano } 18185b8d03eSwren romano 18285b8d03eSwren romano Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 18385b8d03eSwren romano Value v) { 18485b8d03eSwren romano Type tp = v.getType(); 18585b8d03eSwren romano Value zero = constantZero(builder, loc, tp); 18685b8d03eSwren romano if (tp.isa<FloatType>()) 18785b8d03eSwren romano return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 18885b8d03eSwren romano zero); 18985b8d03eSwren romano if (tp.isIntOrIndex()) 19085b8d03eSwren romano return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 19185b8d03eSwren romano zero); 192*28b6d412SAart Bik if (tp.dyn_cast<ComplexType>()) 193*28b6d412SAart Bik return builder.create<complex::NotEqualOp>(loc, v, zero); 19485b8d03eSwren romano llvm_unreachable("Non-numeric type"); 19585b8d03eSwren romano } 196