//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "CodegenUtils.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" using namespace mlir; using namespace mlir::sparse_tensor; //===----------------------------------------------------------------------===// // ExecutionEngine/SparseTensorUtils helper functions. //===----------------------------------------------------------------------===// OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) { switch (width) { case 64: return OverheadType::kU64; case 32: return OverheadType::kU32; case 16: return OverheadType::kU16; case 8: return OverheadType::kU8; case 0: return OverheadType::kIndex; } llvm_unreachable("Unsupported overhead bitwidth"); } OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) { if (tp.isIndex()) return OverheadType::kIndex; if (auto intTp = tp.dyn_cast()) return overheadTypeEncoding(intTp.getWidth()); llvm_unreachable("Unknown overhead type"); } Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) { switch (ot) { case OverheadType::kIndex: return builder.getIndexType(); case OverheadType::kU64: return builder.getIntegerType(64); case OverheadType::kU32: return builder.getIntegerType(32); case OverheadType::kU16: return builder.getIntegerType(16); case OverheadType::kU8: return builder.getIntegerType(8); } llvm_unreachable("Unknown OverheadType"); } OverheadType mlir::sparse_tensor::pointerOverheadTypeEncoding( const SparseTensorEncodingAttr &enc) { return overheadTypeEncoding(enc.getPointerBitWidth()); } OverheadType mlir::sparse_tensor::indexOverheadTypeEncoding( const SparseTensorEncodingAttr &enc) { return overheadTypeEncoding(enc.getIndexBitWidth()); } Type mlir::sparse_tensor::getPointerOverheadType( Builder &builder, const SparseTensorEncodingAttr &enc) { return getOverheadType(builder, pointerOverheadTypeEncoding(enc)); } Type mlir::sparse_tensor::getIndexOverheadType( Builder &builder, const SparseTensorEncodingAttr &enc) { return getOverheadType(builder, indexOverheadTypeEncoding(enc)); } // TODO: Adjust the naming convention for the constructors of // `OverheadType` so we can use the `FOREVERY_O` x-macro here instead // of `FOREVERY_FIXED_O`; to further reduce the possibility of typo bugs // or things getting out of sync. StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) { switch (ot) { case OverheadType::kIndex: return "0"; #define CASE(ONAME, O) \ case OverheadType::kU##ONAME: \ return #ONAME; FOREVERY_FIXED_O(CASE) #undef CASE } llvm_unreachable("Unknown OverheadType"); } StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) { return overheadTypeFunctionSuffix(overheadTypeEncoding(tp)); } PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) { if (elemTp.isF64()) return PrimaryType::kF64; if (elemTp.isF32()) return PrimaryType::kF32; if (elemTp.isF16()) return PrimaryType::kF16; if (elemTp.isBF16()) return PrimaryType::kBF16; if (elemTp.isInteger(64)) return PrimaryType::kI64; if (elemTp.isInteger(32)) return PrimaryType::kI32; if (elemTp.isInteger(16)) return PrimaryType::kI16; if (elemTp.isInteger(8)) return PrimaryType::kI8; if (auto complexTp = elemTp.dyn_cast()) { auto complexEltTp = complexTp.getElementType(); if (complexEltTp.isF64()) return PrimaryType::kC64; if (complexEltTp.isF32()) return PrimaryType::kC32; } llvm_unreachable("Unknown primary type"); } StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) { switch (pt) { #define CASE(VNAME, V) \ case PrimaryType::k##VNAME: \ return #VNAME; FOREVERY_V(CASE) #undef CASE } llvm_unreachable("Unknown PrimaryType"); } StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) { return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp)); } DimLevelType mlir::sparse_tensor::dimLevelTypeEncoding( SparseTensorEncodingAttr::DimLevelType dlt) { switch (dlt) { case SparseTensorEncodingAttr::DimLevelType::Dense: return DimLevelType::kDense; case SparseTensorEncodingAttr::DimLevelType::Compressed: return DimLevelType::kCompressed; case SparseTensorEncodingAttr::DimLevelType::Singleton: return DimLevelType::kSingleton; } llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType"); } //===----------------------------------------------------------------------===// // Misc code generators. //===----------------------------------------------------------------------===// mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) { if (tp.isa()) return builder.getFloatAttr(tp, 1.0); if (tp.isa()) return builder.getIndexAttr(1); if (auto intTp = tp.dyn_cast()) return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1)); if (tp.isa()) { auto shapedTp = tp.cast(); if (auto one = getOneAttr(builder, shapedTp.getElementType())) return DenseElementsAttr::get(shapedTp, one); } llvm_unreachable("Unsupported attribute type"); } Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc, Value v) { Type tp = v.getType(); Value zero = constantZero(builder, loc, tp); if (tp.isa()) return builder.create(loc, arith::CmpFPredicate::UNE, v, zero); if (tp.isIntOrIndex()) return builder.create(loc, arith::CmpIPredicate::ne, v, zero); if (tp.dyn_cast()) return builder.create(loc, v, zero); llvm_unreachable("Non-numeric type"); }