1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the translation between an MLIR LLVM dialect module and 10 // the corresponding LLVMIR module. It only handles core LLVM IR operations. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 15 16 #include "DebugTranslation.h" 17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 18 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 19 #include "mlir/IR/Attributes.h" 20 #include "mlir/IR/BuiltinOps.h" 21 #include "mlir/IR/BuiltinTypes.h" 22 #include "mlir/IR/RegionGraphTraits.h" 23 #include "mlir/Support/LLVM.h" 24 #include "mlir/Target/LLVMIR/TypeTranslation.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 27 #include "llvm/ADT/PostOrderIterator.h" 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" 30 #include "llvm/IR/BasicBlock.h" 31 #include "llvm/IR/CFG.h" 32 #include "llvm/IR/Constants.h" 33 #include "llvm/IR/DerivedTypes.h" 34 #include "llvm/IR/IRBuilder.h" 35 #include "llvm/IR/InlineAsm.h" 36 #include "llvm/IR/LLVMContext.h" 37 #include "llvm/IR/MDBuilder.h" 38 #include "llvm/IR/Module.h" 39 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 40 #include "llvm/Transforms/Utils/Cloning.h" 41 42 using namespace mlir; 43 using namespace mlir::LLVM; 44 using namespace mlir::LLVM::detail; 45 46 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" 47 48 /// Builds a constant of a sequential LLVM type `type`, potentially containing 49 /// other sequential types recursively, from the individual constant values 50 /// provided in `constants`. `shape` contains the number of elements in nested 51 /// sequential types. Reports errors at `loc` and returns nullptr on error. 52 static llvm::Constant * 53 buildSequentialConstant(ArrayRef<llvm::Constant *> &constants, 54 ArrayRef<int64_t> shape, llvm::Type *type, 55 Location loc) { 56 if (shape.empty()) { 57 llvm::Constant *result = constants.front(); 58 constants = constants.drop_front(); 59 return result; 60 } 61 62 llvm::Type *elementType; 63 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) { 64 elementType = arrayTy->getElementType(); 65 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) { 66 elementType = vectorTy->getElementType(); 67 } else { 68 emitError(loc) << "expected sequential LLVM types wrapping a scalar"; 69 return nullptr; 70 } 71 72 SmallVector<llvm::Constant *, 8> nested; 73 nested.reserve(shape.front()); 74 for (int64_t i = 0; i < shape.front(); ++i) { 75 nested.push_back(buildSequentialConstant(constants, shape.drop_front(), 76 elementType, loc)); 77 if (!nested.back()) 78 return nullptr; 79 } 80 81 if (shape.size() == 1 && type->isVectorTy()) 82 return llvm::ConstantVector::get(nested); 83 return llvm::ConstantArray::get( 84 llvm::ArrayType::get(elementType, shape.front()), nested); 85 } 86 87 /// Returns the first non-sequential type nested in sequential types. 88 static llvm::Type *getInnermostElementType(llvm::Type *type) { 89 do { 90 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) { 91 type = arrayTy->getElementType(); 92 } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) { 93 type = vectorTy->getElementType(); 94 } else { 95 return type; 96 } 97 } while (true); 98 } 99 100 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. 101 /// This currently supports integer, floating point, splat and dense element 102 /// attributes and combinations thereof. In case of error, report it to `loc` 103 /// and return nullptr. 104 llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, 105 Attribute attr, 106 Location loc) { 107 if (!attr) 108 return llvm::UndefValue::get(llvmType); 109 if (llvmType->isStructTy()) { 110 emitError(loc, "struct types are not supported in constants"); 111 return nullptr; 112 } 113 // For integer types, we allow a mismatch in sizes as the index type in 114 // MLIR might have a different size than the index type in the LLVM module. 115 if (auto intAttr = attr.dyn_cast<IntegerAttr>()) 116 return llvm::ConstantInt::get( 117 llvmType, 118 intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth())); 119 if (auto floatAttr = attr.dyn_cast<FloatAttr>()) 120 return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); 121 if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>()) 122 return llvm::ConstantExpr::getBitCast(lookupFunction(funcAttr.getValue()), 123 llvmType); 124 if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { 125 llvm::Type *elementType; 126 uint64_t numElements; 127 if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) { 128 elementType = arrayTy->getElementType(); 129 numElements = arrayTy->getNumElements(); 130 } else { 131 auto *vectorTy = cast<llvm::FixedVectorType>(llvmType); 132 elementType = vectorTy->getElementType(); 133 numElements = vectorTy->getNumElements(); 134 } 135 // Splat value is a scalar. Extract it only if the element type is not 136 // another sequence type. The recursion terminates because each step removes 137 // one outer sequential type. 138 bool elementTypeSequential = 139 isa<llvm::ArrayType, llvm::VectorType>(elementType); 140 llvm::Constant *child = getLLVMConstant( 141 elementType, 142 elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc); 143 if (!child) 144 return nullptr; 145 if (llvmType->isVectorTy()) 146 return llvm::ConstantVector::getSplat( 147 llvm::ElementCount::get(numElements, /*Scalable=*/false), child); 148 if (llvmType->isArrayTy()) { 149 auto *arrayType = llvm::ArrayType::get(elementType, numElements); 150 SmallVector<llvm::Constant *, 8> constants(numElements, child); 151 return llvm::ConstantArray::get(arrayType, constants); 152 } 153 } 154 155 if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) { 156 assert(elementsAttr.getType().hasStaticShape()); 157 assert(elementsAttr.getNumElements() != 0 && 158 "unexpected empty elements attribute"); 159 assert(!elementsAttr.getType().getShape().empty() && 160 "unexpected empty elements attribute shape"); 161 162 SmallVector<llvm::Constant *, 8> constants; 163 constants.reserve(elementsAttr.getNumElements()); 164 llvm::Type *innermostType = getInnermostElementType(llvmType); 165 for (auto n : elementsAttr.getValues<Attribute>()) { 166 constants.push_back(getLLVMConstant(innermostType, n, loc)); 167 if (!constants.back()) 168 return nullptr; 169 } 170 ArrayRef<llvm::Constant *> constantsRef = constants; 171 llvm::Constant *result = buildSequentialConstant( 172 constantsRef, elementsAttr.getType().getShape(), llvmType, loc); 173 assert(constantsRef.empty() && "did not consume all elemental constants"); 174 return result; 175 } 176 177 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 178 return llvm::ConstantDataArray::get( 179 llvmModule->getContext(), ArrayRef<char>{stringAttr.getValue().data(), 180 stringAttr.getValue().size()}); 181 } 182 emitError(loc, "unsupported constant value"); 183 return nullptr; 184 } 185 186 /// Convert MLIR integer comparison predicate to LLVM IR comparison predicate. 187 static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) { 188 switch (p) { 189 case LLVM::ICmpPredicate::eq: 190 return llvm::CmpInst::Predicate::ICMP_EQ; 191 case LLVM::ICmpPredicate::ne: 192 return llvm::CmpInst::Predicate::ICMP_NE; 193 case LLVM::ICmpPredicate::slt: 194 return llvm::CmpInst::Predicate::ICMP_SLT; 195 case LLVM::ICmpPredicate::sle: 196 return llvm::CmpInst::Predicate::ICMP_SLE; 197 case LLVM::ICmpPredicate::sgt: 198 return llvm::CmpInst::Predicate::ICMP_SGT; 199 case LLVM::ICmpPredicate::sge: 200 return llvm::CmpInst::Predicate::ICMP_SGE; 201 case LLVM::ICmpPredicate::ult: 202 return llvm::CmpInst::Predicate::ICMP_ULT; 203 case LLVM::ICmpPredicate::ule: 204 return llvm::CmpInst::Predicate::ICMP_ULE; 205 case LLVM::ICmpPredicate::ugt: 206 return llvm::CmpInst::Predicate::ICMP_UGT; 207 case LLVM::ICmpPredicate::uge: 208 return llvm::CmpInst::Predicate::ICMP_UGE; 209 } 210 llvm_unreachable("incorrect comparison predicate"); 211 } 212 213 static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) { 214 switch (p) { 215 case LLVM::FCmpPredicate::_false: 216 return llvm::CmpInst::Predicate::FCMP_FALSE; 217 case LLVM::FCmpPredicate::oeq: 218 return llvm::CmpInst::Predicate::FCMP_OEQ; 219 case LLVM::FCmpPredicate::ogt: 220 return llvm::CmpInst::Predicate::FCMP_OGT; 221 case LLVM::FCmpPredicate::oge: 222 return llvm::CmpInst::Predicate::FCMP_OGE; 223 case LLVM::FCmpPredicate::olt: 224 return llvm::CmpInst::Predicate::FCMP_OLT; 225 case LLVM::FCmpPredicate::ole: 226 return llvm::CmpInst::Predicate::FCMP_OLE; 227 case LLVM::FCmpPredicate::one: 228 return llvm::CmpInst::Predicate::FCMP_ONE; 229 case LLVM::FCmpPredicate::ord: 230 return llvm::CmpInst::Predicate::FCMP_ORD; 231 case LLVM::FCmpPredicate::ueq: 232 return llvm::CmpInst::Predicate::FCMP_UEQ; 233 case LLVM::FCmpPredicate::ugt: 234 return llvm::CmpInst::Predicate::FCMP_UGT; 235 case LLVM::FCmpPredicate::uge: 236 return llvm::CmpInst::Predicate::FCMP_UGE; 237 case LLVM::FCmpPredicate::ult: 238 return llvm::CmpInst::Predicate::FCMP_ULT; 239 case LLVM::FCmpPredicate::ule: 240 return llvm::CmpInst::Predicate::FCMP_ULE; 241 case LLVM::FCmpPredicate::une: 242 return llvm::CmpInst::Predicate::FCMP_UNE; 243 case LLVM::FCmpPredicate::uno: 244 return llvm::CmpInst::Predicate::FCMP_UNO; 245 case LLVM::FCmpPredicate::_true: 246 return llvm::CmpInst::Predicate::FCMP_TRUE; 247 } 248 llvm_unreachable("incorrect comparison predicate"); 249 } 250 251 static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) { 252 switch (op) { 253 case LLVM::AtomicBinOp::xchg: 254 return llvm::AtomicRMWInst::BinOp::Xchg; 255 case LLVM::AtomicBinOp::add: 256 return llvm::AtomicRMWInst::BinOp::Add; 257 case LLVM::AtomicBinOp::sub: 258 return llvm::AtomicRMWInst::BinOp::Sub; 259 case LLVM::AtomicBinOp::_and: 260 return llvm::AtomicRMWInst::BinOp::And; 261 case LLVM::AtomicBinOp::nand: 262 return llvm::AtomicRMWInst::BinOp::Nand; 263 case LLVM::AtomicBinOp::_or: 264 return llvm::AtomicRMWInst::BinOp::Or; 265 case LLVM::AtomicBinOp::_xor: 266 return llvm::AtomicRMWInst::BinOp::Xor; 267 case LLVM::AtomicBinOp::max: 268 return llvm::AtomicRMWInst::BinOp::Max; 269 case LLVM::AtomicBinOp::min: 270 return llvm::AtomicRMWInst::BinOp::Min; 271 case LLVM::AtomicBinOp::umax: 272 return llvm::AtomicRMWInst::BinOp::UMax; 273 case LLVM::AtomicBinOp::umin: 274 return llvm::AtomicRMWInst::BinOp::UMin; 275 case LLVM::AtomicBinOp::fadd: 276 return llvm::AtomicRMWInst::BinOp::FAdd; 277 case LLVM::AtomicBinOp::fsub: 278 return llvm::AtomicRMWInst::BinOp::FSub; 279 } 280 llvm_unreachable("incorrect atomic binary operator"); 281 } 282 283 static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) { 284 switch (ordering) { 285 case LLVM::AtomicOrdering::not_atomic: 286 return llvm::AtomicOrdering::NotAtomic; 287 case LLVM::AtomicOrdering::unordered: 288 return llvm::AtomicOrdering::Unordered; 289 case LLVM::AtomicOrdering::monotonic: 290 return llvm::AtomicOrdering::Monotonic; 291 case LLVM::AtomicOrdering::acquire: 292 return llvm::AtomicOrdering::Acquire; 293 case LLVM::AtomicOrdering::release: 294 return llvm::AtomicOrdering::Release; 295 case LLVM::AtomicOrdering::acq_rel: 296 return llvm::AtomicOrdering::AcquireRelease; 297 case LLVM::AtomicOrdering::seq_cst: 298 return llvm::AtomicOrdering::SequentiallyConsistent; 299 } 300 llvm_unreachable("incorrect atomic ordering"); 301 } 302 303 ModuleTranslation::ModuleTranslation(Operation *module, 304 std::unique_ptr<llvm::Module> llvmModule) 305 : mlirModule(module), llvmModule(std::move(llvmModule)), 306 debugTranslation( 307 std::make_unique<DebugTranslation>(module, *this->llvmModule)), 308 ompDialect(module->getContext()->getLoadedDialect("omp")), 309 typeTranslator(this->llvmModule->getContext()) { 310 assert(satisfiesLLVMModule(mlirModule) && 311 "mlirModule should honor LLVM's module semantics."); 312 } 313 ModuleTranslation::~ModuleTranslation() { 314 if (ompBuilder) 315 ompBuilder->finalize(); 316 } 317 318 /// Get the SSA value passed to the current block from the terminator operation 319 /// of its predecessor. 320 static Value getPHISourceValue(Block *current, Block *pred, 321 unsigned numArguments, unsigned index) { 322 Operation &terminator = *pred->getTerminator(); 323 if (isa<LLVM::BrOp>(terminator)) 324 return terminator.getOperand(index); 325 326 SuccessorRange successors = terminator.getSuccessors(); 327 assert(std::adjacent_find(successors.begin(), successors.end()) == 328 successors.end() && 329 "successors with arguments in LLVM branches must be different blocks"); 330 (void)successors; 331 332 // For instructions that branch based on a condition value, we need to take 333 // the operands for the branch that was taken. 334 if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) { 335 // For conditional branches, we take the operands from either the "true" or 336 // the "false" branch. 337 return condBranchOp.getSuccessor(0) == current 338 ? condBranchOp.trueDestOperands()[index] 339 : condBranchOp.falseDestOperands()[index]; 340 } 341 342 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) { 343 // For switches, we take the operands from either the default case, or from 344 // the case branch that was taken. 345 if (switchOp.defaultDestination() == current) 346 return switchOp.defaultOperands()[index]; 347 for (auto i : llvm::enumerate(switchOp.caseDestinations())) 348 if (i.value() == current) 349 return switchOp.getCaseOperands(i.index())[index]; 350 } 351 352 llvm_unreachable("only branch or switch operations can be terminators of a " 353 "block that has successors"); 354 } 355 356 /// Connect the PHI nodes to the results of preceding blocks. 357 template <typename T> 358 static void connectPHINodes(T &func, const ModuleTranslation &state) { 359 // Skip the first block, it cannot be branched to and its arguments correspond 360 // to the arguments of the LLVM function. 361 for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { 362 Block *bb = &*it; 363 llvm::BasicBlock *llvmBB = state.lookupBlock(bb); 364 auto phis = llvmBB->phis(); 365 auto numArguments = bb->getNumArguments(); 366 assert(numArguments == std::distance(phis.begin(), phis.end())); 367 for (auto &numberedPhiNode : llvm::enumerate(phis)) { 368 auto &phiNode = numberedPhiNode.value(); 369 unsigned index = numberedPhiNode.index(); 370 for (auto *pred : bb->getPredecessors()) { 371 // Find the LLVM IR block that contains the converted terminator 372 // instruction and use it in the PHI node. Note that this block is not 373 // necessarily the same as state.lookupBlock(pred), some operations 374 // (in particular, OpenMP operations using OpenMPIRBuilder) may have 375 // split the blocks. 376 llvm::Instruction *terminator = 377 state.lookupBranch(pred->getTerminator()); 378 assert(terminator && "missing the mapping for a terminator"); 379 phiNode.addIncoming( 380 state.lookupValue(getPHISourceValue(bb, pred, numArguments, index)), 381 terminator->getParent()); 382 } 383 } 384 } 385 } 386 387 /// Sort function blocks topologically. 388 template <typename T> 389 static llvm::SetVector<Block *> topologicalSort(T &f) { 390 // For each block that has not been visited yet (i.e. that has no 391 // predecessors), add it to the list as well as its successors. 392 llvm::SetVector<Block *> blocks; 393 for (Block &b : f) { 394 if (blocks.count(&b) == 0) { 395 llvm::ReversePostOrderTraversal<Block *> traversal(&b); 396 blocks.insert(traversal.begin(), traversal.end()); 397 } 398 } 399 assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted"); 400 401 return blocks; 402 } 403 404 /// Convert the OpenMP parallel Operation to LLVM IR. 405 LogicalResult 406 ModuleTranslation::convertOmpParallel(Operation &opInst, 407 llvm::IRBuilder<> &builder) { 408 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; 409 // TODO: support error propagation in OpenMPIRBuilder and use it instead of 410 // relying on captured variables. 411 LogicalResult bodyGenStatus = success(); 412 413 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, 414 llvm::BasicBlock &continuationBlock) { 415 // ParallelOp has only one region associated with it. 416 auto ®ion = cast<omp::ParallelOp>(opInst).getRegion(); 417 convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(), 418 continuationBlock, builder, bodyGenStatus); 419 }; 420 421 // TODO: Perform appropriate actions according to the data-sharing 422 // attribute (shared, private, firstprivate, ...) of variables. 423 // Currently defaults to shared. 424 auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, 425 llvm::Value &, llvm::Value &vPtr, 426 llvm::Value *&replacementValue) -> InsertPointTy { 427 replacementValue = &vPtr; 428 429 return codeGenIP; 430 }; 431 432 // TODO: Perform finalization actions for variables. This has to be 433 // called for variables which have destructors/finalizers. 434 auto finiCB = [&](InsertPointTy codeGenIP) {}; 435 436 llvm::Value *ifCond = nullptr; 437 if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var()) 438 ifCond = lookupValue(ifExprVar); 439 llvm::Value *numThreads = nullptr; 440 if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var()) 441 numThreads = lookupValue(numThreadsVar); 442 llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default; 443 if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val()) 444 pbKind = llvm::omp::getProcBindKind(bind.getValue()); 445 // TODO: Is the Parallel construct cancellable? 446 bool isCancellable = false; 447 // TODO: Determine the actual alloca insertion point, e.g., the function 448 // entry or the alloca insertion point as provided by the body callback 449 // above. 450 llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP()); 451 if (failed(bodyGenStatus)) 452 return failure(); 453 builder.restoreIP( 454 ompBuilder->createParallel(builder, allocaIP, bodyGenCB, privCB, finiCB, 455 ifCond, numThreads, pbKind, isCancellable)); 456 return success(); 457 } 458 459 void ModuleTranslation::convertOmpOpRegions( 460 Region ®ion, StringRef blockName, 461 llvm::BasicBlock &sourceBlock, llvm::BasicBlock &continuationBlock, 462 llvm::IRBuilder<> &builder, LogicalResult &bodyGenStatus) { 463 llvm::LLVMContext &llvmContext = builder.getContext(); 464 for (Block &bb : region) { 465 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create( 466 llvmContext, blockName, builder.GetInsertBlock()->getParent()); 467 mapBlock(&bb, llvmBB); 468 } 469 470 llvm::Instruction *sourceTerminator = sourceBlock.getTerminator(); 471 472 // Convert blocks one by one in topological order to ensure 473 // defs are converted before uses. 474 llvm::SetVector<Block *> blocks = topologicalSort(region); 475 for (Block *bb : blocks) { 476 llvm::BasicBlock *llvmBB = lookupBlock(bb); 477 // Retarget the branch of the entry block to the entry block of the 478 // converted region (regions are single-entry). 479 if (bb->isEntryBlock()) { 480 assert(sourceTerminator->getNumSuccessors() == 1 && 481 "provided entry block has multiple successors"); 482 assert(sourceTerminator->getSuccessor(0) == &continuationBlock && 483 "ContinuationBlock is not the successor of the entry block"); 484 sourceTerminator->setSuccessor(0, llvmBB); 485 } 486 487 llvm::IRBuilder<>::InsertPointGuard guard(builder); 488 if (failed(convertBlock(*bb, bb->isEntryBlock(), builder))) { 489 bodyGenStatus = failure(); 490 return; 491 } 492 493 // Special handling for `omp.yield` and `omp.terminator` (we may have more 494 // than one): they return the control to the parent OpenMP dialect operation 495 // so replace them with the branch to the continuation block. We handle this 496 // here to avoid relying inter-function communication through the 497 // ModuleTranslation class to set up the correct insertion point. This is 498 // also consistent with MLIR's idiom of handling special region terminators 499 // in the same code that handles the region-owning operation. 500 if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator())) 501 builder.CreateBr(&continuationBlock); 502 } 503 // Finally, after all blocks have been traversed and values mapped, 504 // connect the PHI nodes to the results of preceding blocks. 505 connectPHINodes(region, *this); 506 } 507 508 LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst, 509 llvm::IRBuilder<> &builder) { 510 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; 511 // TODO: support error propagation in OpenMPIRBuilder and use it instead of 512 // relying on captured variables. 513 LogicalResult bodyGenStatus = success(); 514 515 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, 516 llvm::BasicBlock &continuationBlock) { 517 // MasterOp has only one region associated with it. 518 auto ®ion = cast<omp::MasterOp>(opInst).getRegion(); 519 convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(), 520 continuationBlock, builder, bodyGenStatus); 521 }; 522 523 // TODO: Perform finalization actions for variables. This has to be 524 // called for variables which have destructors/finalizers. 525 auto finiCB = [&](InsertPointTy codeGenIP) {}; 526 527 builder.restoreIP(ompBuilder->createMaster(builder, bodyGenCB, finiCB)); 528 return success(); 529 } 530 531 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder. 532 LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst, 533 llvm::IRBuilder<> &builder) { 534 auto loop = cast<omp::WsLoopOp>(opInst); 535 // TODO: this should be in the op verifier instead. 536 if (loop.lowerBound().empty()) 537 return failure(); 538 539 if (loop.getNumLoops() != 1) 540 return opInst.emitOpError("collapsed loops not yet supported"); 541 542 if (loop.schedule_val().hasValue() && 543 omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) != 544 omp::ClauseScheduleKind::Static) 545 return opInst.emitOpError( 546 "only static (default) loop schedule is currently supported"); 547 548 // Find the loop configuration. 549 llvm::Value *lowerBound = lookupValue(loop.lowerBound()[0]); 550 llvm::Value *upperBound = lookupValue(loop.upperBound()[0]); 551 llvm::Value *step = lookupValue(loop.step()[0]); 552 llvm::Type *ivType = step->getType(); 553 llvm::Value *chunk = loop.schedule_chunk_var() 554 ? lookupValue(loop.schedule_chunk_var()) 555 : llvm::ConstantInt::get(ivType, 1); 556 557 // Set up the source location value for OpenMP runtime. 558 llvm::DISubprogram *subprogram = 559 builder.GetInsertBlock()->getParent()->getSubprogram(); 560 const llvm::DILocation *diLoc = 561 debugTranslation->translateLoc(opInst.getLoc(), subprogram); 562 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(), 563 llvm::DebugLoc(diLoc)); 564 565 // Generator of the canonical loop body. Produces an SESE region of basic 566 // blocks. 567 // TODO: support error propagation in OpenMPIRBuilder and use it instead of 568 // relying on captured variables. 569 LogicalResult bodyGenStatus = success(); 570 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) { 571 llvm::IRBuilder<>::InsertPointGuard guard(builder); 572 573 // Make sure further conversions know about the induction variable. 574 mapValue(loop.getRegion().front().getArgument(0), iv); 575 576 llvm::BasicBlock *entryBlock = ip.getBlock(); 577 llvm::BasicBlock *exitBlock = 578 entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit"); 579 580 // Convert the body of the loop. 581 convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock, 582 *exitBlock, builder, bodyGenStatus); 583 }; 584 585 // Delegate actual loop construction to the OpenMP IRBuilder. 586 // TODO: this currently assumes WsLoop is semantically similar to SCF loop, 587 // i.e. it has a positive step, uses signed integer semantics. Reconsider 588 // this code when WsLoop clearly supports more cases. 589 llvm::BasicBlock *insertBlock = builder.GetInsertBlock(); 590 llvm::CanonicalLoopInfo *loopInfo = ompBuilder->createCanonicalLoop( 591 ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true, 592 /*InclusiveStop=*/loop.inclusive()); 593 if (failed(bodyGenStatus)) 594 return failure(); 595 596 // TODO: get the alloca insertion point from the parallel operation builder. 597 // If we insert the at the top of the current function, they will be passed as 598 // extra arguments into the function the parallel operation builder outlines. 599 // Put them at the start of the current block for now. 600 llvm::OpenMPIRBuilder::InsertPointTy allocaIP( 601 insertBlock, insertBlock->getFirstInsertionPt()); 602 loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP, 603 !loop.nowait(), chunk); 604 605 // Continue building IR after the loop. 606 builder.restoreIP(loopInfo->getAfterIP()); 607 return success(); 608 } 609 610 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR 611 /// (including OpenMP runtime calls). 612 LogicalResult 613 ModuleTranslation::convertOmpOperation(Operation &opInst, 614 llvm::IRBuilder<> &builder) { 615 if (!ompBuilder) { 616 ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule); 617 ompBuilder->initialize(); 618 } 619 return llvm::TypeSwitch<Operation *, LogicalResult>(&opInst) 620 .Case([&](omp::BarrierOp) { 621 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier); 622 return success(); 623 }) 624 .Case([&](omp::TaskwaitOp) { 625 ompBuilder->createTaskwait(builder.saveIP()); 626 return success(); 627 }) 628 .Case([&](omp::TaskyieldOp) { 629 ompBuilder->createTaskyield(builder.saveIP()); 630 return success(); 631 }) 632 .Case([&](omp::FlushOp) { 633 // No support in Openmp runtime function (__kmpc_flush) to accept 634 // the argument list. 635 // OpenMP standard states the following: 636 // "An implementation may implement a flush with a list by ignoring 637 // the list, and treating it the same as a flush without a list." 638 // 639 // The argument list is discarded so that, flush with a list is treated 640 // same as a flush without a list. 641 ompBuilder->createFlush(builder.saveIP()); 642 return success(); 643 }) 644 .Case( 645 [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); }) 646 .Case([&](omp::MasterOp) { return convertOmpMaster(opInst, builder); }) 647 .Case([&](omp::WsLoopOp) { return convertOmpWsLoop(opInst, builder); }) 648 .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) { 649 // `yield` and `terminator` can be just omitted. The block structure was 650 // created in the function that handles their parent operation. 651 assert(op->getNumOperands() == 0 && 652 "unexpected OpenMP terminator with operands"); 653 return success(); 654 }) 655 .Default([&](Operation *inst) { 656 return inst->emitError("unsupported OpenMP operation: ") 657 << inst->getName(); 658 }); 659 } 660 661 static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { 662 using llvmFMF = llvm::FastMathFlags; 663 using FuncT = void (llvmFMF::*)(bool); 664 const std::pair<FastmathFlags, FuncT> handlers[] = { 665 // clang-format off 666 {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, 667 {FastmathFlags::ninf, &llvmFMF::setNoInfs}, 668 {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, 669 {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, 670 {FastmathFlags::contract, &llvmFMF::setAllowContract}, 671 {FastmathFlags::afn, &llvmFMF::setApproxFunc}, 672 {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, 673 {FastmathFlags::fast, &llvmFMF::setFast}, 674 // clang-format on 675 }; 676 llvm::FastMathFlags ret; 677 auto fmf = op.fastmathFlags(); 678 for (auto it : handlers) 679 if (bitEnumContains(fmf, it.first)) 680 (ret.*(it.second))(true); 681 return ret; 682 } 683 684 /// Given a single MLIR operation, create the corresponding LLVM IR operation 685 /// using the `builder`. LLVM IR Builder does not have a generic interface so 686 /// this has to be a long chain of `if`s calling different functions with a 687 /// different number of arguments. 688 LogicalResult ModuleTranslation::convertOperation(Operation &opInst, 689 llvm::IRBuilder<> &builder) { 690 auto extractPosition = [](ArrayAttr attr) { 691 SmallVector<unsigned, 4> position; 692 position.reserve(attr.size()); 693 for (Attribute v : attr) 694 position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue()); 695 return position; 696 }; 697 698 llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); 699 if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst)) 700 builder.setFastMathFlags(getFastmathFlags(fmf)); 701 702 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc" 703 704 // Emit function calls. If the "callee" attribute is present, this is a 705 // direct function call and we also need to look up the remapped function 706 // itself. Otherwise, this is an indirect call and the callee is the first 707 // operand, look it up as a normal value. Return the llvm::Value representing 708 // the function result, which may be of llvm::VoidTy type. 709 auto convertCall = [this, &builder](Operation &op) -> llvm::Value * { 710 auto operands = lookupValues(op.getOperands()); 711 ArrayRef<llvm::Value *> operandsRef(operands); 712 if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee")) 713 return builder.CreateCall(lookupFunction(attr.getValue()), operandsRef); 714 auto *calleePtrType = 715 cast<llvm::PointerType>(operandsRef.front()->getType()); 716 auto *calleeType = 717 cast<llvm::FunctionType>(calleePtrType->getElementType()); 718 return builder.CreateCall(calleeType, operandsRef.front(), 719 operandsRef.drop_front()); 720 }; 721 722 // Emit calls. If the called function has a result, remap the corresponding 723 // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. 724 if (isa<LLVM::CallOp>(opInst)) { 725 llvm::Value *result = convertCall(opInst); 726 if (opInst.getNumResults() != 0) { 727 mapValue(opInst.getResult(0), result); 728 return success(); 729 } 730 // Check that LLVM call returns void for 0-result functions. 731 return success(result->getType()->isVoidTy()); 732 } 733 734 if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) { 735 // TODO: refactor function type creation which usually occurs in std-LLVM 736 // conversion. 737 SmallVector<Type, 8> operandTypes; 738 operandTypes.reserve(inlineAsmOp.operands().size()); 739 for (auto t : inlineAsmOp.operands().getTypes()) 740 operandTypes.push_back(t); 741 742 Type resultType; 743 if (inlineAsmOp.getNumResults() == 0) { 744 resultType = LLVM::LLVMVoidType::get(mlirModule->getContext()); 745 } else { 746 assert(inlineAsmOp.getNumResults() == 1); 747 resultType = inlineAsmOp.getResultTypes()[0]; 748 } 749 auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); 750 llvm::InlineAsm *inlineAsmInst = 751 inlineAsmOp.asm_dialect().hasValue() 752 ? llvm::InlineAsm::get( 753 static_cast<llvm::FunctionType *>(convertType(ft)), 754 inlineAsmOp.asm_string(), inlineAsmOp.constraints(), 755 inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(), 756 convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect())) 757 : llvm::InlineAsm::get( 758 static_cast<llvm::FunctionType *>(convertType(ft)), 759 inlineAsmOp.asm_string(), inlineAsmOp.constraints(), 760 inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack()); 761 llvm::Value *result = 762 builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands())); 763 if (opInst.getNumResults() != 0) 764 mapValue(opInst.getResult(0), result); 765 return success(); 766 } 767 768 if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) { 769 auto operands = lookupValues(opInst.getOperands()); 770 ArrayRef<llvm::Value *> operandsRef(operands); 771 if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) { 772 builder.CreateInvoke(lookupFunction(attr.getValue()), 773 lookupBlock(invOp.getSuccessor(0)), 774 lookupBlock(invOp.getSuccessor(1)), operandsRef); 775 } else { 776 auto *calleePtrType = 777 cast<llvm::PointerType>(operandsRef.front()->getType()); 778 auto *calleeType = 779 cast<llvm::FunctionType>(calleePtrType->getElementType()); 780 builder.CreateInvoke( 781 calleeType, operandsRef.front(), lookupBlock(invOp.getSuccessor(0)), 782 lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front()); 783 } 784 return success(); 785 } 786 787 if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) { 788 llvm::Type *ty = convertType(lpOp.getType()); 789 llvm::LandingPadInst *lpi = 790 builder.CreateLandingPad(ty, lpOp.getNumOperands()); 791 792 // Add clauses 793 for (llvm::Value *operand : lookupValues(lpOp.getOperands())) { 794 // All operands should be constant - checked by verifier 795 if (auto *constOperand = dyn_cast<llvm::Constant>(operand)) 796 lpi->addClause(constOperand); 797 } 798 mapValue(lpOp.getResult(), lpi); 799 return success(); 800 } 801 802 // Emit branches. We need to look up the remapped blocks and ignore the block 803 // arguments that were transformed into PHI nodes. 804 if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) { 805 llvm::BranchInst *branch = 806 builder.CreateBr(lookupBlock(brOp.getSuccessor())); 807 mapBranch(&opInst, branch); 808 return success(); 809 } 810 if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) { 811 auto weights = condbrOp.branch_weights(); 812 llvm::MDNode *branchWeights = nullptr; 813 if (weights) { 814 // Map weight attributes to LLVM metadata. 815 auto trueWeight = 816 weights.getValue().getValue(0).cast<IntegerAttr>().getInt(); 817 auto falseWeight = 818 weights.getValue().getValue(1).cast<IntegerAttr>().getInt(); 819 branchWeights = 820 llvm::MDBuilder(llvmModule->getContext()) 821 .createBranchWeights(static_cast<uint32_t>(trueWeight), 822 static_cast<uint32_t>(falseWeight)); 823 } 824 llvm::BranchInst *branch = builder.CreateCondBr( 825 lookupValue(condbrOp.getOperand(0)), 826 lookupBlock(condbrOp.getSuccessor(0)), 827 lookupBlock(condbrOp.getSuccessor(1)), branchWeights); 828 mapBranch(&opInst, branch); 829 return success(); 830 } 831 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) { 832 llvm::MDNode *branchWeights = nullptr; 833 if (auto weights = switchOp.branch_weights()) { 834 llvm::SmallVector<uint32_t> weightValues; 835 weightValues.reserve(weights->size()); 836 for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>()) 837 weightValues.push_back(weight.getLimitedValue()); 838 branchWeights = llvm::MDBuilder(llvmModule->getContext()) 839 .createBranchWeights(weightValues); 840 } 841 842 llvm::SwitchInst *switchInst = 843 builder.CreateSwitch(lookupValue(switchOp.value()), 844 lookupBlock(switchOp.defaultDestination()), 845 switchOp.caseDestinations().size(), branchWeights); 846 847 auto *ty = 848 llvm::cast<llvm::IntegerType>(convertType(switchOp.value().getType())); 849 for (auto i : 850 llvm::zip(switchOp.case_values()->cast<DenseIntElementsAttr>(), 851 switchOp.caseDestinations())) 852 switchInst->addCase( 853 llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), 854 lookupBlock(std::get<1>(i))); 855 856 mapBranch(&opInst, switchInst); 857 return success(); 858 } 859 860 // Emit addressof. We need to look up the global value referenced by the 861 // operation and store it in the MLIR-to-LLVM value mapping. This does not 862 // emit any LLVM instruction. 863 if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) { 864 LLVM::GlobalOp global = addressOfOp.getGlobal(); 865 LLVM::LLVMFuncOp function = addressOfOp.getFunction(); 866 867 // The verifier should not have allowed this. 868 assert((global || function) && 869 "referencing an undefined global or function"); 870 871 mapValue(addressOfOp.getResult(), global 872 ? globalsMapping.lookup(global) 873 : lookupFunction(function.getName())); 874 return success(); 875 } 876 877 if (ompDialect && opInst.getDialect() == ompDialect) 878 return convertOmpOperation(opInst, builder); 879 880 return opInst.emitError("unsupported or non-LLVM operation: ") 881 << opInst.getName(); 882 } 883 884 /// Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes 885 /// to define values corresponding to the MLIR block arguments. These nodes 886 /// are not connected to the source basic blocks, which may not exist yet. Uses 887 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have 888 /// been created for `bb` and included in the block mapping. Inserts new 889 /// instructions at the end of the block and leaves `builder` in a state 890 /// suitable for further insertion into the end of the block. 891 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments, 892 llvm::IRBuilder<> &builder) { 893 builder.SetInsertPoint(lookupBlock(&bb)); 894 auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram(); 895 896 // Before traversing operations, make block arguments available through 897 // value remapping and PHI nodes, but do not add incoming edges for the PHI 898 // nodes just yet: those values may be defined by this or following blocks. 899 // This step is omitted if "ignoreArguments" is set. The arguments of the 900 // first block have been already made available through the remapping of 901 // LLVM function arguments. 902 if (!ignoreArguments) { 903 auto predecessors = bb.getPredecessors(); 904 unsigned numPredecessors = 905 std::distance(predecessors.begin(), predecessors.end()); 906 for (auto arg : bb.getArguments()) { 907 auto wrappedType = arg.getType(); 908 if (!isCompatibleType(wrappedType)) 909 return emitError(bb.front().getLoc(), 910 "block argument does not have an LLVM type"); 911 llvm::Type *type = convertType(wrappedType); 912 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); 913 mapValue(arg, phi); 914 } 915 } 916 917 // Traverse operations. 918 for (auto &op : bb) { 919 // Set the current debug location within the builder. 920 builder.SetCurrentDebugLocation( 921 debugTranslation->translateLoc(op.getLoc(), subprogram)); 922 923 if (failed(convertOperation(op, builder))) 924 return failure(); 925 } 926 927 return success(); 928 } 929 930 /// Create named global variables that correspond to llvm.mlir.global 931 /// definitions. 932 LogicalResult ModuleTranslation::convertGlobals() { 933 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { 934 llvm::Type *type = convertType(op.getType()); 935 llvm::Constant *cst = llvm::UndefValue::get(type); 936 if (op.getValueOrNull()) { 937 // String attributes are treated separately because they cannot appear as 938 // in-function constants and are thus not supported by getLLVMConstant. 939 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { 940 cst = llvm::ConstantDataArray::getString( 941 llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); 942 type = cst->getType(); 943 } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(), 944 op.getLoc()))) { 945 return failure(); 946 } 947 } else if (Block *initializer = op.getInitializerBlock()) { 948 llvm::IRBuilder<> builder(llvmModule->getContext()); 949 for (auto &op : initializer->without_terminator()) { 950 if (failed(convertOperation(op, builder)) || 951 !isa<llvm::Constant>(lookupValue(op.getResult(0)))) 952 return emitError(op.getLoc(), "unemittable constant value"); 953 } 954 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator()); 955 cst = cast<llvm::Constant>(lookupValue(ret.getOperand(0))); 956 } 957 958 auto linkage = convertLinkageToLLVM(op.linkage()); 959 bool anyExternalLinkage = 960 ((linkage == llvm::GlobalVariable::ExternalLinkage && 961 isa<llvm::UndefValue>(cst)) || 962 linkage == llvm::GlobalVariable::ExternalWeakLinkage); 963 auto addrSpace = op.addr_space(); 964 auto *var = new llvm::GlobalVariable( 965 *llvmModule, type, op.constant(), linkage, 966 anyExternalLinkage ? nullptr : cst, op.sym_name(), 967 /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal, addrSpace); 968 969 globalsMapping.try_emplace(op, var); 970 } 971 972 return success(); 973 } 974 975 /// Attempts to add an attribute identified by `key`, optionally with the given 976 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the 977 /// attribute has a kind known to LLVM IR, create the attribute of this kind, 978 /// otherwise keep it as a string attribute. Performs additional checks for 979 /// attributes known to have or not have a value in order to avoid assertions 980 /// inside LLVM upon construction. 981 static LogicalResult checkedAddLLVMFnAttribute(Location loc, 982 llvm::Function *llvmFunc, 983 StringRef key, 984 StringRef value = StringRef()) { 985 auto kind = llvm::Attribute::getAttrKindFromName(key); 986 if (kind == llvm::Attribute::None) { 987 llvmFunc->addFnAttr(key, value); 988 return success(); 989 } 990 991 if (llvm::Attribute::doesAttrKindHaveArgument(kind)) { 992 if (value.empty()) 993 return emitError(loc) << "LLVM attribute '" << key << "' expects a value"; 994 995 int result; 996 if (!value.getAsInteger(/*Radix=*/0, result)) 997 llvmFunc->addFnAttr( 998 llvm::Attribute::get(llvmFunc->getContext(), kind, result)); 999 else 1000 llvmFunc->addFnAttr(key, value); 1001 return success(); 1002 } 1003 1004 if (!value.empty()) 1005 return emitError(loc) << "LLVM attribute '" << key 1006 << "' does not expect a value, found '" << value 1007 << "'"; 1008 1009 llvmFunc->addFnAttr(kind); 1010 return success(); 1011 } 1012 1013 /// Attaches the attributes listed in the given array attribute to `llvmFunc`. 1014 /// Reports error to `loc` if any and returns immediately. Expects `attributes` 1015 /// to be an array attribute containing either string attributes, treated as 1016 /// value-less LLVM attributes, or array attributes containing two string 1017 /// attributes, with the first string being the name of the corresponding LLVM 1018 /// attribute and the second string beings its value. Note that even integer 1019 /// attributes are expected to have their values expressed as strings. 1020 static LogicalResult 1021 forwardPassthroughAttributes(Location loc, Optional<ArrayAttr> attributes, 1022 llvm::Function *llvmFunc) { 1023 if (!attributes) 1024 return success(); 1025 1026 for (Attribute attr : *attributes) { 1027 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 1028 if (failed( 1029 checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue()))) 1030 return failure(); 1031 continue; 1032 } 1033 1034 auto arrayAttr = attr.dyn_cast<ArrayAttr>(); 1035 if (!arrayAttr || arrayAttr.size() != 2) 1036 return emitError(loc) 1037 << "expected 'passthrough' to contain string or array attributes"; 1038 1039 auto keyAttr = arrayAttr[0].dyn_cast<StringAttr>(); 1040 auto valueAttr = arrayAttr[1].dyn_cast<StringAttr>(); 1041 if (!keyAttr || !valueAttr) 1042 return emitError(loc) 1043 << "expected arrays within 'passthrough' to contain two strings"; 1044 1045 if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(), 1046 valueAttr.getValue()))) 1047 return failure(); 1048 } 1049 return success(); 1050 } 1051 1052 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { 1053 // Clear the block, branch value mappings, they are only relevant within one 1054 // function. 1055 blockMapping.clear(); 1056 valueMapping.clear(); 1057 branchMapping.clear(); 1058 llvm::Function *llvmFunc = lookupFunction(func.getName()); 1059 1060 // Translate the debug information for this function. 1061 debugTranslation->translate(func, *llvmFunc); 1062 1063 // Add function arguments to the value remapping table. 1064 // If there was noalias info then we decorate each argument accordingly. 1065 unsigned int argIdx = 0; 1066 for (auto kvp : llvm::zip(func.getArguments(), llvmFunc->args())) { 1067 llvm::Argument &llvmArg = std::get<1>(kvp); 1068 BlockArgument mlirArg = std::get<0>(kvp); 1069 1070 if (auto attr = func.getArgAttrOfType<BoolAttr>( 1071 argIdx, LLVMDialect::getNoAliasAttrName())) { 1072 // NB: Attribute already verified to be boolean, so check if we can indeed 1073 // attach the attribute to this argument, based on its type. 1074 auto argTy = mlirArg.getType(); 1075 if (!argTy.isa<LLVM::LLVMPointerType>()) 1076 return func.emitError( 1077 "llvm.noalias attribute attached to LLVM non-pointer argument"); 1078 if (attr.getValue()) 1079 llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias); 1080 } 1081 1082 if (auto attr = func.getArgAttrOfType<IntegerAttr>( 1083 argIdx, LLVMDialect::getAlignAttrName())) { 1084 // NB: Attribute already verified to be int, so check if we can indeed 1085 // attach the attribute to this argument, based on its type. 1086 auto argTy = mlirArg.getType(); 1087 if (!argTy.isa<LLVM::LLVMPointerType>()) 1088 return func.emitError( 1089 "llvm.align attribute attached to LLVM non-pointer argument"); 1090 llvmArg.addAttrs( 1091 llvm::AttrBuilder().addAlignmentAttr(llvm::Align(attr.getInt()))); 1092 } 1093 1094 if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.sret")) { 1095 auto argTy = mlirArg.getType(); 1096 if (!argTy.isa<LLVM::LLVMPointerType>()) 1097 return func.emitError( 1098 "llvm.sret attribute attached to LLVM non-pointer argument"); 1099 llvmArg.addAttrs(llvm::AttrBuilder().addStructRetAttr( 1100 llvmArg.getType()->getPointerElementType())); 1101 } 1102 1103 if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.byval")) { 1104 auto argTy = mlirArg.getType(); 1105 if (!argTy.isa<LLVM::LLVMPointerType>()) 1106 return func.emitError( 1107 "llvm.byval attribute attached to LLVM non-pointer argument"); 1108 llvmArg.addAttrs(llvm::AttrBuilder().addByValAttr( 1109 llvmArg.getType()->getPointerElementType())); 1110 } 1111 1112 mapValue(mlirArg, &llvmArg); 1113 argIdx++; 1114 } 1115 1116 // Check the personality and set it. 1117 if (func.personality().hasValue()) { 1118 llvm::Type *ty = llvm::Type::getInt8PtrTy(llvmFunc->getContext()); 1119 if (llvm::Constant *pfunc = 1120 getLLVMConstant(ty, func.personalityAttr(), func.getLoc())) 1121 llvmFunc->setPersonalityFn(pfunc); 1122 } 1123 1124 // First, create all blocks so we can jump to them. 1125 llvm::LLVMContext &llvmContext = llvmFunc->getContext(); 1126 for (auto &bb : func) { 1127 auto *llvmBB = llvm::BasicBlock::Create(llvmContext); 1128 llvmBB->insertInto(llvmFunc); 1129 mapBlock(&bb, llvmBB); 1130 } 1131 1132 // Then, convert blocks one by one in topological order to ensure defs are 1133 // converted before uses. 1134 auto blocks = topologicalSort(func); 1135 for (Block *bb : blocks) { 1136 llvm::IRBuilder<> builder(llvmContext); 1137 if (failed(convertBlock(*bb, bb->isEntryBlock(), builder))) 1138 return failure(); 1139 } 1140 1141 // Finally, after all blocks have been traversed and values mapped, connect 1142 // the PHI nodes to the results of preceding blocks. 1143 connectPHINodes(func, *this); 1144 return success(); 1145 } 1146 1147 LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) { 1148 for (Operation &o : getModuleBody(m).getOperations()) 1149 if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp>(&o) && 1150 !o.hasTrait<OpTrait::IsTerminator>()) 1151 return o.emitOpError("unsupported module-level operation"); 1152 return success(); 1153 } 1154 1155 LogicalResult ModuleTranslation::convertFunctionSignatures() { 1156 // Declare all functions first because there may be function calls that form a 1157 // call graph with cycles, or global initializers that reference functions. 1158 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { 1159 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( 1160 function.getName(), 1161 cast<llvm::FunctionType>(convertType(function.getType()))); 1162 llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee()); 1163 llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage())); 1164 mapFunction(function.getName(), llvmFunc); 1165 1166 // Forward the pass-through attributes to LLVM. 1167 if (failed(forwardPassthroughAttributes(function.getLoc(), 1168 function.passthrough(), llvmFunc))) 1169 return failure(); 1170 } 1171 1172 return success(); 1173 } 1174 1175 LogicalResult ModuleTranslation::convertFunctions() { 1176 // Convert functions. 1177 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { 1178 // Ignore external functions. 1179 if (function.isExternal()) 1180 continue; 1181 1182 if (failed(convertOneFunction(function))) 1183 return failure(); 1184 } 1185 1186 return success(); 1187 } 1188 1189 llvm::Type *ModuleTranslation::convertType(Type type) { 1190 return typeTranslator.translateType(type); 1191 } 1192 1193 /// A helper to look up remapped operands in the value remapping table.` 1194 SmallVector<llvm::Value *, 8> 1195 ModuleTranslation::lookupValues(ValueRange values) { 1196 SmallVector<llvm::Value *, 8> remapped; 1197 remapped.reserve(values.size()); 1198 for (Value v : values) 1199 remapped.push_back(lookupValue(v)); 1200 return remapped; 1201 } 1202 1203 std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule( 1204 Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { 1205 m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>(); 1206 auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext); 1207 if (auto dataLayoutAttr = 1208 m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName())) 1209 llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue()); 1210 if (auto targetTripleAttr = 1211 m->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName())) 1212 llvmModule->setTargetTriple(targetTripleAttr.cast<StringAttr>().getValue()); 1213 1214 // Inject declarations for `malloc` and `free` functions that can be used in 1215 // memref allocation/deallocation coming from standard ops lowering. 1216 llvm::IRBuilder<> builder(llvmContext); 1217 llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(), 1218 builder.getInt64Ty()); 1219 llvmModule->getOrInsertFunction("free", builder.getVoidTy(), 1220 builder.getInt8PtrTy()); 1221 1222 return llvmModule; 1223 } 1224