185b8d03eSwren romano //===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// 285b8d03eSwren romano // 385b8d03eSwren romano // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 485b8d03eSwren romano // See https://llvm.org/LICENSE.txt for license information. 585b8d03eSwren romano // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 685b8d03eSwren romano // 785b8d03eSwren romano //===----------------------------------------------------------------------===// 885b8d03eSwren romano 985b8d03eSwren romano #include "CodegenUtils.h" 1085b8d03eSwren romano 1185b8d03eSwren romano #include "mlir/IR/Types.h" 1285b8d03eSwren romano #include "mlir/IR/Value.h" 1385b8d03eSwren romano 1485b8d03eSwren romano using namespace mlir; 1585b8d03eSwren romano using namespace mlir::sparse_tensor; 1685b8d03eSwren romano 1785b8d03eSwren romano //===----------------------------------------------------------------------===// 1885b8d03eSwren romano // ExecutionEngine/SparseTensorUtils helper functions. 1985b8d03eSwren romano //===----------------------------------------------------------------------===// 2085b8d03eSwren romano 2185b8d03eSwren romano OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { 2285b8d03eSwren romano switch (width) { 23*bc04a470Swren romano case 64: 2485b8d03eSwren romano return OverheadType::kU64; 2585b8d03eSwren romano case 32: 2685b8d03eSwren romano return OverheadType::kU32; 2785b8d03eSwren romano case 16: 2885b8d03eSwren romano return OverheadType::kU16; 2985b8d03eSwren romano case 8: 3085b8d03eSwren romano return OverheadType::kU8; 31*bc04a470Swren romano case 0: 32*bc04a470Swren romano return OverheadType::kIndex; 3385b8d03eSwren romano } 34*bc04a470Swren romano llvm_unreachable("Unsupported overhead bitwidth"); 3585b8d03eSwren romano } 3685b8d03eSwren romano 3785b8d03eSwren romano Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { 3885b8d03eSwren romano switch (ot) { 39*bc04a470Swren romano case OverheadType::kIndex: 40*bc04a470Swren romano return builder.getIndexType(); 4185b8d03eSwren romano case OverheadType::kU64: 4285b8d03eSwren romano return builder.getIntegerType(64); 4385b8d03eSwren romano case OverheadType::kU32: 4485b8d03eSwren romano return builder.getIntegerType(32); 4585b8d03eSwren romano case OverheadType::kU16: 4685b8d03eSwren romano return builder.getIntegerType(16); 4785b8d03eSwren romano case OverheadType::kU8: 4885b8d03eSwren romano return builder.getIntegerType(8); 4985b8d03eSwren romano } 5085b8d03eSwren romano llvm_unreachable("Unknown OverheadType"); 5185b8d03eSwren romano } 5285b8d03eSwren romano 5385b8d03eSwren romano Type mlir::sparse_tensor::getPointerOverheadType( 5485b8d03eSwren romano Builder &builder, const SparseTensorEncodingAttr &enc) { 55*bc04a470Swren romano return getOverheadType(builder, 56*bc04a470Swren romano overheadTypeEncoding(enc.getPointerBitWidth())); 5785b8d03eSwren romano } 5885b8d03eSwren romano 5985b8d03eSwren romano Type mlir::sparse_tensor::getIndexOverheadType( 6085b8d03eSwren romano Builder &builder, const SparseTensorEncodingAttr &enc) { 61*bc04a470Swren romano return getOverheadType(builder, overheadTypeEncoding(enc.getIndexBitWidth())); 6285b8d03eSwren romano } 6385b8d03eSwren romano 6485b8d03eSwren romano PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { 6585b8d03eSwren romano if (elemTp.isF64()) 6685b8d03eSwren romano return PrimaryType::kF64; 6785b8d03eSwren romano if (elemTp.isF32()) 6885b8d03eSwren romano return PrimaryType::kF32; 6985b8d03eSwren romano if (elemTp.isInteger(64)) 7085b8d03eSwren romano return PrimaryType::kI64; 7185b8d03eSwren romano if (elemTp.isInteger(32)) 7285b8d03eSwren romano return PrimaryType::kI32; 7385b8d03eSwren romano if (elemTp.isInteger(16)) 7485b8d03eSwren romano return PrimaryType::kI16; 7585b8d03eSwren romano if (elemTp.isInteger(8)) 7685b8d03eSwren romano return PrimaryType::kI8; 7785b8d03eSwren romano llvm_unreachable("Unknown primary type"); 7885b8d03eSwren romano } 7985b8d03eSwren romano 8085b8d03eSwren romano DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( 8185b8d03eSwren romano SparseTensorEncodingAttr::DimLevelType dlt) { 8285b8d03eSwren romano switch (dlt) { 8385b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Dense: 8485b8d03eSwren romano return DimLevelType::kDense; 8585b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Compressed: 8685b8d03eSwren romano return DimLevelType::kCompressed; 8785b8d03eSwren romano case SparseTensorEncodingAttr::DimLevelType::Singleton: 8885b8d03eSwren romano return DimLevelType::kSingleton; 8985b8d03eSwren romano } 9085b8d03eSwren romano llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); 9185b8d03eSwren romano } 9285b8d03eSwren romano 9385b8d03eSwren romano //===----------------------------------------------------------------------===// 9485b8d03eSwren romano // Misc code generators. 9585b8d03eSwren romano //===----------------------------------------------------------------------===// 9685b8d03eSwren romano 9785b8d03eSwren romano mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { 9885b8d03eSwren romano if (tp.isa<FloatType>()) 9985b8d03eSwren romano return builder.getFloatAttr(tp, 1.0); 10085b8d03eSwren romano if (tp.isa<IndexType>()) 10185b8d03eSwren romano return builder.getIndexAttr(1); 10285b8d03eSwren romano if (auto intTp = tp.dyn_cast<IntegerType>()) 10385b8d03eSwren romano return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); 10485b8d03eSwren romano if (tp.isa<RankedTensorType, VectorType>()) { 10585b8d03eSwren romano auto shapedTp = tp.cast<ShapedType>(); 10685b8d03eSwren romano if (auto one = getOneAttr(builder, shapedTp.getElementType())) 10785b8d03eSwren romano return DenseElementsAttr::get(shapedTp, one); 10885b8d03eSwren romano } 10985b8d03eSwren romano llvm_unreachable("Unsupported attribute type"); 11085b8d03eSwren romano } 11185b8d03eSwren romano 11285b8d03eSwren romano Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, 11385b8d03eSwren romano Value v) { 11485b8d03eSwren romano Type tp = v.getType(); 11585b8d03eSwren romano Value zero = constantZero(builder, loc, tp); 11685b8d03eSwren romano if (tp.isa<FloatType>()) 11785b8d03eSwren romano return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v, 11885b8d03eSwren romano zero); 11985b8d03eSwren romano if (tp.isIntOrIndex()) 12085b8d03eSwren romano return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v, 12185b8d03eSwren romano zero); 12285b8d03eSwren romano llvm_unreachable("Non-numeric type"); 12385b8d03eSwren romano } 124