1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the translation between an MLIR LLVM dialect module and 19 // the corresponding LLVMIR module. It only handles core LLVM IR operations. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 24 25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 26 #include "mlir/IR/Attributes.h" 27 #include "mlir/IR/Module.h" 28 #include "mlir/Support/LLVM.h" 29 30 #include "llvm/ADT/SetVector.h" 31 #include "llvm/IR/BasicBlock.h" 32 #include "llvm/IR/Constants.h" 33 #include "llvm/IR/DerivedTypes.h" 34 #include "llvm/IR/IRBuilder.h" 35 #include "llvm/IR/LLVMContext.h" 36 #include "llvm/IR/Module.h" 37 #include "llvm/Transforms/Utils/Cloning.h" 38 39 namespace mlir { 40 namespace LLVM { 41 42 // Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. 43 // This currently supports integer, floating point, splat and dense element 44 // attributes and combinations thereof. In case of error, report it to `loc` 45 // and return nullptr. 46 llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, 47 Attribute attr, 48 Location loc) { 49 if (!attr) 50 return llvm::UndefValue::get(llvmType); 51 if (auto intAttr = attr.dyn_cast<IntegerAttr>()) 52 return llvm::ConstantInt::get(llvmType, intAttr.getValue()); 53 if (auto floatAttr = attr.dyn_cast<FloatAttr>()) 54 return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); 55 if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>()) 56 return functionMapping.lookup(funcAttr.getValue()); 57 if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { 58 auto *sequentialType = cast<llvm::SequentialType>(llvmType); 59 auto elementType = sequentialType->getElementType(); 60 uint64_t numElements = sequentialType->getNumElements(); 61 auto *child = getLLVMConstant(elementType, splatAttr.getSplatValue(), loc); 62 if (llvmType->isVectorTy()) 63 return llvm::ConstantVector::getSplat(numElements, child); 64 if (llvmType->isArrayTy()) { 65 auto arrayType = llvm::ArrayType::get(elementType, numElements); 66 SmallVector<llvm::Constant *, 8> constants(numElements, child); 67 return llvm::ConstantArray::get(arrayType, constants); 68 } 69 } 70 if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) { 71 auto *sequentialType = cast<llvm::SequentialType>(llvmType); 72 auto elementType = sequentialType->getElementType(); 73 uint64_t numElements = sequentialType->getNumElements(); 74 SmallVector<llvm::Constant *, 8> constants; 75 constants.reserve(numElements); 76 for (auto n : elementsAttr.getValues<Attribute>()) { 77 constants.push_back(getLLVMConstant(elementType, n, loc)); 78 if (!constants.back()) 79 return nullptr; 80 } 81 if (llvmType->isVectorTy()) 82 return llvm::ConstantVector::get(constants); 83 if (llvmType->isArrayTy()) { 84 auto arrayType = llvm::ArrayType::get(elementType, numElements); 85 return llvm::ConstantArray::get(arrayType, constants); 86 } 87 } 88 if (auto stringAttr = attr.dyn_cast<StringAttr>()) { 89 return llvm::ConstantDataArray::get( 90 llvmModule->getContext(), ArrayRef<char>{stringAttr.getValue().data(), 91 stringAttr.getValue().size()}); 92 } 93 emitError(loc, "unsupported constant value"); 94 return nullptr; 95 } 96 97 // Convert MLIR integer comparison predicate to LLVM IR comparison predicate. 98 static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) { 99 switch (p) { 100 case LLVM::ICmpPredicate::eq: 101 return llvm::CmpInst::Predicate::ICMP_EQ; 102 case LLVM::ICmpPredicate::ne: 103 return llvm::CmpInst::Predicate::ICMP_NE; 104 case LLVM::ICmpPredicate::slt: 105 return llvm::CmpInst::Predicate::ICMP_SLT; 106 case LLVM::ICmpPredicate::sle: 107 return llvm::CmpInst::Predicate::ICMP_SLE; 108 case LLVM::ICmpPredicate::sgt: 109 return llvm::CmpInst::Predicate::ICMP_SGT; 110 case LLVM::ICmpPredicate::sge: 111 return llvm::CmpInst::Predicate::ICMP_SGE; 112 case LLVM::ICmpPredicate::ult: 113 return llvm::CmpInst::Predicate::ICMP_ULT; 114 case LLVM::ICmpPredicate::ule: 115 return llvm::CmpInst::Predicate::ICMP_ULE; 116 case LLVM::ICmpPredicate::ugt: 117 return llvm::CmpInst::Predicate::ICMP_UGT; 118 case LLVM::ICmpPredicate::uge: 119 return llvm::CmpInst::Predicate::ICMP_UGE; 120 } 121 llvm_unreachable("incorrect comparison predicate"); 122 } 123 124 static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) { 125 switch (p) { 126 case LLVM::FCmpPredicate::_false: 127 return llvm::CmpInst::Predicate::FCMP_FALSE; 128 case LLVM::FCmpPredicate::oeq: 129 return llvm::CmpInst::Predicate::FCMP_OEQ; 130 case LLVM::FCmpPredicate::ogt: 131 return llvm::CmpInst::Predicate::FCMP_OGT; 132 case LLVM::FCmpPredicate::oge: 133 return llvm::CmpInst::Predicate::FCMP_OGE; 134 case LLVM::FCmpPredicate::olt: 135 return llvm::CmpInst::Predicate::FCMP_OLT; 136 case LLVM::FCmpPredicate::ole: 137 return llvm::CmpInst::Predicate::FCMP_OLE; 138 case LLVM::FCmpPredicate::one: 139 return llvm::CmpInst::Predicate::FCMP_ONE; 140 case LLVM::FCmpPredicate::ord: 141 return llvm::CmpInst::Predicate::FCMP_ORD; 142 case LLVM::FCmpPredicate::ueq: 143 return llvm::CmpInst::Predicate::FCMP_UEQ; 144 case LLVM::FCmpPredicate::ugt: 145 return llvm::CmpInst::Predicate::FCMP_UGT; 146 case LLVM::FCmpPredicate::uge: 147 return llvm::CmpInst::Predicate::FCMP_UGE; 148 case LLVM::FCmpPredicate::ult: 149 return llvm::CmpInst::Predicate::FCMP_ULT; 150 case LLVM::FCmpPredicate::ule: 151 return llvm::CmpInst::Predicate::FCMP_ULE; 152 case LLVM::FCmpPredicate::une: 153 return llvm::CmpInst::Predicate::FCMP_UNE; 154 case LLVM::FCmpPredicate::uno: 155 return llvm::CmpInst::Predicate::FCMP_UNO; 156 case LLVM::FCmpPredicate::_true: 157 return llvm::CmpInst::Predicate::FCMP_TRUE; 158 } 159 llvm_unreachable("incorrect comparison predicate"); 160 } 161 162 // Given a single MLIR operation, create the corresponding LLVM IR operation 163 // using the `builder`. LLVM IR Builder does not have a generic interface so 164 // this has to be a long chain of `if`s calling different functions with a 165 // different number of arguments. 166 LogicalResult ModuleTranslation::convertOperation(Operation &opInst, 167 llvm::IRBuilder<> &builder) { 168 auto extractPosition = [](ArrayAttr attr) { 169 SmallVector<unsigned, 4> position; 170 position.reserve(attr.size()); 171 for (Attribute v : attr) 172 position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue()); 173 return position; 174 }; 175 176 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc" 177 178 // Emit function calls. If the "callee" attribute is present, this is a 179 // direct function call and we also need to look up the remapped function 180 // itself. Otherwise, this is an indirect call and the callee is the first 181 // operand, look it up as a normal value. Return the llvm::Value representing 182 // the function result, which may be of llvm::VoidTy type. 183 auto convertCall = [this, &builder](Operation &op) -> llvm::Value * { 184 auto operands = lookupValues(op.getOperands()); 185 ArrayRef<llvm::Value *> operandsRef(operands); 186 if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee")) { 187 return builder.CreateCall(functionMapping.lookup(attr.getValue()), 188 operandsRef); 189 } else { 190 return builder.CreateCall(operandsRef.front(), operandsRef.drop_front()); 191 } 192 }; 193 194 // Emit calls. If the called function has a result, remap the corresponding 195 // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. 196 if (isa<LLVM::CallOp>(opInst)) { 197 llvm::Value *result = convertCall(opInst); 198 if (opInst.getNumResults() != 0) { 199 valueMapping[opInst.getResult(0)] = result; 200 return success(); 201 } 202 // Check that LLVM call returns void for 0-result functions. 203 return success(result->getType()->isVoidTy()); 204 } 205 206 // Emit branches. We need to look up the remapped blocks and ignore the block 207 // arguments that were transformed into PHI nodes. 208 if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) { 209 builder.CreateBr(blockMapping[brOp.getSuccessor(0)]); 210 return success(); 211 } 212 if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) { 213 builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)), 214 blockMapping[condbrOp.getSuccessor(0)], 215 blockMapping[condbrOp.getSuccessor(1)]); 216 return success(); 217 } 218 219 // Emit addressof. We need to look up the global value referenced by the 220 // operation and store it in the MLIR-to-LLVM value mapping. This does not 221 // emit any LLVM instruction. 222 if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) { 223 LLVM::GlobalOp global = addressOfOp.getGlobal(); 224 // The verifier should not have allowed this. 225 assert(global && "referencing an undefined global"); 226 227 valueMapping[addressOfOp.getResult()] = globalsMapping.lookup(global); 228 return success(); 229 } 230 231 return opInst.emitError("unsupported or non-LLVM operation: ") 232 << opInst.getName(); 233 } 234 235 // Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes 236 // to define values corresponding to the MLIR block arguments. These nodes 237 // are not connected to the source basic blocks, which may not exist yet. 238 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { 239 llvm::IRBuilder<> builder(blockMapping[&bb]); 240 241 // Before traversing operations, make block arguments available through 242 // value remapping and PHI nodes, but do not add incoming edges for the PHI 243 // nodes just yet: those values may be defined by this or following blocks. 244 // This step is omitted if "ignoreArguments" is set. The arguments of the 245 // first block have been already made available through the remapping of 246 // LLVM function arguments. 247 if (!ignoreArguments) { 248 auto predecessors = bb.getPredecessors(); 249 unsigned numPredecessors = 250 std::distance(predecessors.begin(), predecessors.end()); 251 for (auto *arg : bb.getArguments()) { 252 auto wrappedType = arg->getType().dyn_cast<LLVM::LLVMType>(); 253 if (!wrappedType) 254 return emitError(bb.front().getLoc(), 255 "block argument does not have an LLVM type"); 256 llvm::Type *type = wrappedType.getUnderlyingType(); 257 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); 258 valueMapping[arg] = phi; 259 } 260 } 261 262 // Traverse operations. 263 for (auto &op : bb) { 264 if (failed(convertOperation(op, builder))) 265 return failure(); 266 } 267 268 return success(); 269 } 270 271 // Convert the LLVM dialect linkage type to LLVM IR linkage type. 272 llvm::GlobalVariable::LinkageTypes convertLinkageType(LLVM::Linkage linkage) { 273 switch (linkage) { 274 case LLVM::Linkage::Private: 275 return llvm::GlobalValue::PrivateLinkage; 276 case LLVM::Linkage::Internal: 277 return llvm::GlobalValue::InternalLinkage; 278 case LLVM::Linkage::AvailableExternally: 279 return llvm::GlobalValue::AvailableExternallyLinkage; 280 case LLVM::Linkage::Linkonce: 281 return llvm::GlobalValue::LinkOnceAnyLinkage; 282 case LLVM::Linkage::Weak: 283 return llvm::GlobalValue::WeakAnyLinkage; 284 case LLVM::Linkage::Common: 285 return llvm::GlobalValue::CommonLinkage; 286 case LLVM::Linkage::Appending: 287 return llvm::GlobalValue::AppendingLinkage; 288 case LLVM::Linkage::ExternWeak: 289 return llvm::GlobalValue::ExternalWeakLinkage; 290 case LLVM::Linkage::LinkonceODR: 291 return llvm::GlobalValue::LinkOnceODRLinkage; 292 case LLVM::Linkage::WeakODR: 293 return llvm::GlobalValue::WeakODRLinkage; 294 case LLVM::Linkage::External: 295 return llvm::GlobalValue::ExternalLinkage; 296 } 297 llvm_unreachable("unknown linkage type"); 298 } 299 300 // Create named global variables that correspond to llvm.mlir.global 301 // definitions. 302 void ModuleTranslation::convertGlobals() { 303 for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) { 304 llvm::Type *type = op.getType().getUnderlyingType(); 305 llvm::Constant *cst = llvm::UndefValue::get(type); 306 if (op.getValueOrNull()) { 307 // String attributes are treated separately because they cannot appear as 308 // in-function constants and are thus not supported by getLLVMConstant. 309 if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) { 310 cst = llvm::ConstantDataArray::getString( 311 llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false); 312 type = cst->getType(); 313 } else { 314 cst = getLLVMConstant(type, op.getValueOrNull(), op.getLoc()); 315 } 316 } else if (Block *initializer = op.getInitializerBlock()) { 317 llvm::IRBuilder<> builder(llvmModule->getContext()); 318 for (auto &op : initializer->without_terminator()) { 319 if (failed(convertOperation(op, builder)) || 320 !isa<llvm::Constant>(valueMapping.lookup(op.getResult(0)))) { 321 emitError(op.getLoc(), "unemittable constant value"); 322 return; 323 } 324 } 325 ReturnOp ret = cast<ReturnOp>(initializer->getTerminator()); 326 cst = cast<llvm::Constant>(valueMapping.lookup(ret.getOperand(0))); 327 } 328 329 auto linkage = convertLinkageType(op.linkage()); 330 bool anyExternalLinkage = 331 (linkage == llvm::GlobalVariable::ExternalLinkage || 332 linkage == llvm::GlobalVariable::ExternalWeakLinkage); 333 auto addrSpace = op.addr_space().getLimitedValue(); 334 auto *var = new llvm::GlobalVariable( 335 *llvmModule, type, op.constant(), linkage, 336 anyExternalLinkage ? nullptr : cst, op.sym_name(), 337 /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal, addrSpace); 338 339 globalsMapping.try_emplace(op, var); 340 } 341 } 342 343 // Get the SSA value passed to the current block from the terminator operation 344 // of its predecessor. 345 static Value *getPHISourceValue(Block *current, Block *pred, 346 unsigned numArguments, unsigned index) { 347 auto &terminator = *pred->getTerminator(); 348 if (isa<LLVM::BrOp>(terminator)) { 349 return terminator.getOperand(index); 350 } 351 352 // For conditional branches, we need to check if the current block is reached 353 // through the "true" or the "false" branch and take the relevant operands. 354 auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator); 355 assert(condBranchOp && 356 "only branch operations can be terminators of a block that " 357 "has successors"); 358 assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && 359 "successors with arguments in LLVM conditional branches must be " 360 "different blocks"); 361 362 return condBranchOp.getSuccessor(0) == current 363 ? terminator.getSuccessorOperand(0, index) 364 : terminator.getSuccessorOperand(1, index); 365 } 366 367 void ModuleTranslation::connectPHINodes(LLVMFuncOp func) { 368 // Skip the first block, it cannot be branched to and its arguments correspond 369 // to the arguments of the LLVM function. 370 for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { 371 Block *bb = &*it; 372 llvm::BasicBlock *llvmBB = blockMapping.lookup(bb); 373 auto phis = llvmBB->phis(); 374 auto numArguments = bb->getNumArguments(); 375 assert(numArguments == std::distance(phis.begin(), phis.end())); 376 for (auto &numberedPhiNode : llvm::enumerate(phis)) { 377 auto &phiNode = numberedPhiNode.value(); 378 unsigned index = numberedPhiNode.index(); 379 for (auto *pred : bb->getPredecessors()) { 380 phiNode.addIncoming(valueMapping.lookup(getPHISourceValue( 381 bb, pred, numArguments, index)), 382 blockMapping.lookup(pred)); 383 } 384 } 385 } 386 } 387 388 // TODO(mlir-team): implement an iterative version 389 static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) { 390 blocks.insert(b); 391 for (Block *bb : b->getSuccessors()) { 392 if (blocks.count(bb) == 0) 393 topologicalSortImpl(blocks, bb); 394 } 395 } 396 397 // Sort function blocks topologically. 398 static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) { 399 // For each blocks that has not been visited yet (i.e. that has no 400 // predecessors), add it to the list and traverse its successors in DFS 401 // preorder. 402 llvm::SetVector<Block *> blocks; 403 for (Block &b : f.getBlocks()) { 404 if (blocks.count(&b) == 0) 405 topologicalSortImpl(blocks, &b); 406 } 407 assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted"); 408 409 return blocks; 410 } 411 412 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { 413 // Clear the block and value mappings, they are only relevant within one 414 // function. 415 blockMapping.clear(); 416 valueMapping.clear(); 417 llvm::Function *llvmFunc = functionMapping.lookup(func.getName()); 418 // Add function arguments to the value remapping table. 419 // If there was noalias info then we decorate each argument accordingly. 420 unsigned int argIdx = 0; 421 for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) { 422 llvm::Argument &llvmArg = std::get<1>(kvp); 423 BlockArgument *mlirArg = std::get<0>(kvp); 424 425 if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.noalias")) { 426 // NB: Attribute already verified to be boolean, so check if we can indeed 427 // attach the attribute to this argument, based on its type. 428 auto argTy = mlirArg->getType().dyn_cast<LLVM::LLVMType>(); 429 if (!argTy.getUnderlyingType()->isPointerTy()) 430 return func.emitError( 431 "llvm.noalias attribute attached to LLVM non-pointer argument"); 432 if (attr.getValue()) 433 llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias); 434 } 435 valueMapping[mlirArg] = &llvmArg; 436 argIdx++; 437 } 438 439 // First, create all blocks so we can jump to them. 440 llvm::LLVMContext &llvmContext = llvmFunc->getContext(); 441 for (auto &bb : func) { 442 auto *llvmBB = llvm::BasicBlock::Create(llvmContext); 443 llvmBB->insertInto(llvmFunc); 444 blockMapping[&bb] = llvmBB; 445 } 446 447 // Then, convert blocks one by one in topological order to ensure defs are 448 // converted before uses. 449 auto blocks = topologicalSort(func); 450 for (auto indexedBB : llvm::enumerate(blocks)) { 451 auto *bb = indexedBB.value(); 452 if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))) 453 return failure(); 454 } 455 456 // Finally, after all blocks have been traversed and values mapped, connect 457 // the PHI nodes to the results of preceding blocks. 458 connectPHINodes(func); 459 return success(); 460 } 461 462 LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) { 463 for (Operation &o : getModuleBody(m).getOperations()) 464 if (!isa<LLVM::LLVMFuncOp>(&o) && !isa<LLVM::GlobalOp>(&o) && 465 !o.isKnownTerminator()) 466 return o.emitOpError("unsupported module-level operation"); 467 return success(); 468 } 469 470 LogicalResult ModuleTranslation::convertFunctions() { 471 // Declare all functions first because there may be function calls that form a 472 // call graph with cycles. 473 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { 474 llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( 475 function.getName(), 476 cast<llvm::FunctionType>(function.getType().getUnderlyingType())); 477 assert(isa<llvm::Function>(llvmFuncCst.getCallee())); 478 functionMapping[function.getName()] = 479 cast<llvm::Function>(llvmFuncCst.getCallee()); 480 } 481 482 // Convert functions. 483 for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) { 484 // Ignore external functions. 485 if (function.isExternal()) 486 continue; 487 488 if (failed(convertOneFunction(function))) 489 return failure(); 490 } 491 492 return success(); 493 } 494 495 std::unique_ptr<llvm::Module> 496 ModuleTranslation::prepareLLVMModule(Operation *m) { 497 auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); 498 assert(dialect && "LLVM dialect must be registered"); 499 500 auto llvmModule = llvm::CloneModule(dialect->getLLVMModule()); 501 if (!llvmModule) 502 return nullptr; 503 504 llvm::LLVMContext &llvmContext = llvmModule->getContext(); 505 llvm::IRBuilder<> builder(llvmContext); 506 507 // Inject declarations for `malloc` and `free` functions that can be used in 508 // memref allocation/deallocation coming from standard ops lowering. 509 llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(), 510 builder.getInt64Ty()); 511 llvmModule->getOrInsertFunction("free", builder.getVoidTy(), 512 builder.getInt8PtrTy()); 513 514 return llvmModule; 515 } 516 517 } // namespace LLVM 518 } // namespace mlir 519