//===- TypeFromLLVM.cpp - type translation from LLVM to MLIR IR -===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/TypeFromLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Type.h" using namespace mlir; namespace mlir { namespace LLVM { namespace detail { /// Support for translating LLVM IR types to MLIR LLVM dialect types. class TypeFromLLVMIRTranslatorImpl { public: /// Constructs a class creating types in the given MLIR context. TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {} /// Translates the given type. Type translateType(llvm::Type *type) { if (knownTranslations.count(type)) return knownTranslations.lookup(type); Type translated = llvm::TypeSwitch(type) .Case( [this](auto *type) { return this->translate(type); }) .Default([this](llvm::Type *type) { return translatePrimitiveType(type); }); knownTranslations.try_emplace(type, translated); return translated; } private: /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, /// type. Type translatePrimitiveType(llvm::Type *type) { if (type->isVoidTy()) return LLVM::LLVMVoidType::get(&context); if (type->isHalfTy()) return Float16Type::get(&context); if (type->isBFloatTy()) return BFloat16Type::get(&context); if (type->isFloatTy()) return Float32Type::get(&context); if (type->isDoubleTy()) return Float64Type::get(&context); if (type->isFP128Ty()) return Float128Type::get(&context); if (type->isX86_FP80Ty()) return Float80Type::get(&context); if (type->isPPC_FP128Ty()) return LLVM::LLVMPPCFP128Type::get(&context); if (type->isX86_MMXTy()) return LLVM::LLVMX86MMXType::get(&context); if (type->isLabelTy()) return LLVM::LLVMLabelType::get(&context); if (type->isMetadataTy()) return LLVM::LLVMMetadataType::get(&context); if (type->isTokenTy()) return LLVM::LLVMTokenType::get(&context); llvm_unreachable("not a primitive type"); } /// Translates the given array type. Type translate(llvm::ArrayType *type) { return LLVM::LLVMArrayType::get(translateType(type->getElementType()), type->getNumElements()); } /// Translates the given function type. Type translate(llvm::FunctionType *type) { SmallVector paramTypes; translateTypes(type->params(), paramTypes); return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), paramTypes, type->isVarArg()); } /// Translates the given integer type. Type translate(llvm::IntegerType *type) { return IntegerType::get(&context, type->getBitWidth()); } /// Translates the given pointer type. Type translate(llvm::PointerType *type) { if (type->isOpaque()) return LLVM::LLVMPointerType::get(&context, type->getAddressSpace()); return LLVM::LLVMPointerType::get( translateType(type->getNonOpaquePointerElementType()), type->getAddressSpace()); } /// Translates the given structure type. Type translate(llvm::StructType *type) { SmallVector subtypes; if (type->isLiteral()) { translateTypes(type->subtypes(), subtypes); return LLVM::LLVMStructType::getLiteral(&context, subtypes, type->isPacked()); } if (type->isOpaque()) return LLVM::LLVMStructType::getOpaque(type->getName(), &context); LLVM::LLVMStructType translated = LLVM::LLVMStructType::getIdentified(&context, type->getName()); knownTranslations.try_emplace(type, translated); translateTypes(type->subtypes(), subtypes); LogicalResult bodySet = translated.setBody(subtypes, type->isPacked()); assert(succeeded(bodySet) && "could not set the body of an identified struct"); (void)bodySet; return translated; } /// Translates the given fixed-vector type. Type translate(llvm::FixedVectorType *type) { return LLVM::getFixedVectorType(translateType(type->getElementType()), type->getNumElements()); } /// Translates the given scalable-vector type. Type translate(llvm::ScalableVectorType *type) { return LLVM::LLVMScalableVectorType::get( translateType(type->getElementType()), type->getMinNumElements()); } /// Translates a list of types. void translateTypes(ArrayRef types, SmallVectorImpl &result) { result.reserve(result.size() + types.size()); for (llvm::Type *type : types) result.push_back(translateType(type)); } /// Map of known translations. Serves as a cache and as recursion stopper for /// translating recursive structs. llvm::DenseMap knownTranslations; /// The context in which MLIR types are created. MLIRContext &context; }; } // namespace detail } // namespace LLVM } // namespace mlir LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context) : impl(new detail::TypeFromLLVMIRTranslatorImpl(context)) {} LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() = default; Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { return impl->translateType(type); }