1 //===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion ---------*- C++ -*-===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements a translation between the MLIR LLVM dialect and LLVM IR. 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "mlir/IR/Attributes.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/LLVMIR/LLVMDialect.h" 25 #include "mlir/StandardOps/Ops.h" 26 #include "mlir/Support/FileUtilities.h" 27 #include "mlir/Support/LLVM.h" 28 #include "mlir/Target/LLVMIR.h" 29 #include "mlir/Translation.h" 30 31 #include "llvm/ADT/SetVector.h" 32 #include "llvm/IR/BasicBlock.h" 33 #include "llvm/IR/Constants.h" 34 #include "llvm/IR/DerivedTypes.h" 35 #include "llvm/IR/IRBuilder.h" 36 #include "llvm/IR/LLVMContext.h" 37 #include "llvm/IR/Module.h" 38 #include "llvm/Support/ToolOutputFile.h" 39 #include "llvm/Transforms/Utils/Cloning.h" 40 41 using namespace mlir; 42 43 namespace { 44 // Implementation class for module translation. Holds a reference to the module 45 // being translated, and the mappings between the original and the translated 46 // functions, basic blocks and values. It is practically easier to hold these 47 // mappings in one class since the conversion of control flow operations 48 // needs to look up block and function mappings. 49 class ModuleTranslation { 50 public: 51 // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an 52 // LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an 53 // LLVMContext, the LLVM IR module will be created in that context. 54 static std::unique_ptr<llvm::Module> translateModule(Module &m); 55 56 private: 57 explicit ModuleTranslation(Module &module) : mlirModule(module) {} 58 59 bool convertFunctions(); 60 bool convertOneFunction(Function &func); 61 void connectPHINodes(Function &func); 62 bool convertBlock(Block &bb, bool ignoreArguments); 63 bool convertOperation(Operation &op, llvm::IRBuilder<> &builder); 64 65 template <typename Range> 66 SmallVector<llvm::Value *, 8> lookupValues(Range &&values); 67 68 llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, 69 Location loc); 70 71 // Original and translated module. 72 Module &mlirModule; 73 std::unique_ptr<llvm::Module> llvmModule; 74 75 // Mappings between original and translated values, used for lookups. 76 llvm::DenseMap<Function *, llvm::Function *> functionMapping; 77 llvm::DenseMap<Value *, llvm::Value *> valueMapping; 78 llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping; 79 }; 80 } // end anonymous namespace 81 82 // Convert an MLIR function type to LLVM IR. Arguments of the function must of 83 // MLIR LLVM IR dialect types. Use `loc` as a location when reporting errors. 84 // Return nullptr on errors. 85 static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext, 86 FunctionType type, Location loc, 87 bool isVarArgs) { 88 assert(type && "expected non-null type"); 89 90 auto context = type.getContext(); 91 if (type.getNumResults() > 1) 92 return context->emitError(loc, 93 "LLVM functions can only have 0 or 1 result"), 94 nullptr; 95 96 SmallVector<llvm::Type *, 8> argTypes; 97 argTypes.reserve(type.getNumInputs()); 98 for (auto t : type.getInputs()) { 99 auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>(); 100 if (!wrappedLLVMType) 101 return context->emitError(loc, "non-LLVM function argument type"), 102 nullptr; 103 argTypes.push_back(wrappedLLVMType.getUnderlyingType()); 104 } 105 106 if (type.getNumResults() == 0) 107 return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes, 108 isVarArgs); 109 110 auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>(); 111 if (!wrappedResultType) 112 return context->emitError(loc, "non-LLVM function result"), nullptr; 113 114 return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(), 115 argTypes, isVarArgs); 116 } 117 118 // Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`. 119 // This currently supports integer, floating point, splat and dense element 120 // attributes and combinations thereof. In case of error, report it to `loc` 121 // and return nullptr. 122 llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, 123 Attribute attr, 124 Location loc) { 125 if (auto intAttr = attr.dyn_cast<IntegerAttr>()) 126 return llvm::ConstantInt::get(llvmType, intAttr.getValue()); 127 if (auto floatAttr = attr.dyn_cast<FloatAttr>()) 128 return llvm::ConstantFP::get(llvmType, floatAttr.getValue()); 129 if (auto funcAttr = attr.dyn_cast<FunctionAttr>()) 130 return functionMapping.lookup(funcAttr.getValue()); 131 if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) { 132 auto *vectorType = cast<llvm::VectorType>(llvmType); 133 auto *child = getLLVMConstant(vectorType->getElementType(), 134 splatAttr.getValue(), loc); 135 return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child); 136 } 137 if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) { 138 auto *vectorType = cast<llvm::VectorType>(llvmType); 139 SmallVector<llvm::Constant *, 8> constants; 140 uint64_t numElements = vectorType->getNumElements(); 141 constants.reserve(numElements); 142 SmallVector<Attribute, 8> nested; 143 denseAttr.getValues(nested); 144 for (auto n : nested) { 145 constants.push_back( 146 getLLVMConstant(vectorType->getElementType(), n, loc)); 147 if (!constants.back()) 148 return nullptr; 149 } 150 return llvm::ConstantVector::get(constants); 151 } 152 mlirModule.getContext()->emitError(loc, "unsupported constant value"); 153 return nullptr; 154 } 155 156 // Convert MLIR integer comparison predicate to LLVM IR comparison predicate. 157 static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) { 158 switch (p) { 159 case CmpIPredicate::EQ: 160 return llvm::CmpInst::Predicate::ICMP_EQ; 161 case CmpIPredicate::NE: 162 return llvm::CmpInst::Predicate::ICMP_NE; 163 case CmpIPredicate::SLT: 164 return llvm::CmpInst::Predicate::ICMP_SLT; 165 case CmpIPredicate::SLE: 166 return llvm::CmpInst::Predicate::ICMP_SLE; 167 case CmpIPredicate::SGT: 168 return llvm::CmpInst::Predicate::ICMP_SGT; 169 case CmpIPredicate::SGE: 170 return llvm::CmpInst::Predicate::ICMP_SGE; 171 case CmpIPredicate::ULT: 172 return llvm::CmpInst::Predicate::ICMP_ULT; 173 case CmpIPredicate::ULE: 174 return llvm::CmpInst::Predicate::ICMP_ULE; 175 case CmpIPredicate::UGT: 176 return llvm::CmpInst::Predicate::ICMP_UGT; 177 case CmpIPredicate::UGE: 178 return llvm::CmpInst::Predicate::ICMP_UGE; 179 default: 180 llvm_unreachable("incorrect comparison predicate"); 181 } 182 } 183 184 // A helper to look up remapped operands in the value remapping table. 185 template <typename Range> 186 SmallVector<llvm::Value *, 8> ModuleTranslation::lookupValues(Range &&values) { 187 SmallVector<llvm::Value *, 8> remapped; 188 remapped.reserve(llvm::size(values)); 189 for (Value *v : values) { 190 remapped.push_back(valueMapping.lookup(v)); 191 } 192 return remapped; 193 } 194 195 // Given a single MLIR operation, create the corresponding LLVM IR operation 196 // using the `builder`. LLVM IR Builder does not have a generic interface so 197 // this has to be a long chain of `if`s calling different functions with a 198 // different number of arguments. 199 bool ModuleTranslation::convertOperation(Operation &opInst, 200 llvm::IRBuilder<> &builder) { 201 auto extractPosition = [](ArrayAttr attr) { 202 SmallVector<unsigned, 4> position; 203 position.reserve(attr.size()); 204 for (Attribute v : attr) 205 position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue()); 206 return position; 207 }; 208 209 #include "mlir/LLVMIR/LLVMConversions.inc" 210 211 // Emit function calls. If the "callee" attribute is present, this is a 212 // direct function call and we also need to look up the remapped function 213 // itself. Otherwise, this is an indirect call and the callee is the first 214 // operand, look it up as a normal value. Return the llvm::Value representing 215 // the function result, which may be of llvm::VoidTy type. 216 auto convertCall = [this, &builder](Operation &op) -> llvm::Value * { 217 auto operands = lookupValues(op.getOperands()); 218 ArrayRef<llvm::Value *> operandsRef(operands); 219 if (auto attr = op.getAttrOfType<FunctionAttr>("callee")) { 220 return builder.CreateCall(functionMapping.lookup(attr.getValue()), 221 operandsRef); 222 } else { 223 return builder.CreateCall(operandsRef.front(), operandsRef.drop_front()); 224 } 225 }; 226 227 // Emit calls. If the called function has a result, remap the corresponding 228 // value. Note that LLVM IR dialect CallOp has either 0 or 1 result. 229 if (opInst.isa<LLVM::CallOp>()) { 230 llvm::Value *result = convertCall(opInst); 231 if (opInst.getNumResults() != 0) { 232 valueMapping[opInst.getResult(0)] = result; 233 return false; 234 } 235 // Check that LLVM call returns void for 0-result functions. 236 return !result->getType()->isVoidTy(); 237 } 238 239 // Emit branches. We need to look up the remapped blocks and ignore the block 240 // arguments that were transformed into PHI nodes. 241 if (auto brOp = opInst.dyn_cast<LLVM::BrOp>()) { 242 builder.CreateBr(blockMapping[brOp.getSuccessor(0)]); 243 return false; 244 } 245 if (auto condbrOp = opInst.dyn_cast<LLVM::CondBrOp>()) { 246 builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)), 247 blockMapping[condbrOp.getSuccessor(0)], 248 blockMapping[condbrOp.getSuccessor(1)]); 249 return false; 250 } 251 252 opInst.emitError("unsupported or non-LLVM operation: " + 253 opInst.getName().getStringRef()); 254 return true; 255 } 256 257 // Convert block to LLVM IR. Unless `ignoreArguments` is set, emit PHI nodes 258 // to define values corresponding to the MLIR block arguments. These nodes 259 // are not connected to the source basic blocks, which may not exist yet. 260 bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) { 261 llvm::IRBuilder<> builder(blockMapping[&bb]); 262 263 // Before traversing operations, make block arguments available through 264 // value remapping and PHI nodes, but do not add incoming edges for the PHI 265 // nodes just yet: those values may be defined by this or following blocks. 266 // This step is omitted if "ignoreArguments" is set. The arguments of the 267 // first block have been already made available through the remapping of 268 // LLVM function arguments. 269 if (!ignoreArguments) { 270 auto predecessors = bb.getPredecessors(); 271 unsigned numPredecessors = 272 std::distance(predecessors.begin(), predecessors.end()); 273 for (auto *arg : bb.getArguments()) { 274 auto wrappedType = arg->getType().dyn_cast<LLVM::LLVMType>(); 275 if (!wrappedType) { 276 arg->getType().getContext()->emitError( 277 bb.front().getLoc(), "block argument does not have an LLVM type"); 278 return true; 279 } 280 llvm::Type *type = wrappedType.getUnderlyingType(); 281 llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); 282 valueMapping[arg] = phi; 283 } 284 } 285 286 // Traverse operations. 287 for (auto &op : bb) { 288 if (convertOperation(op, builder)) 289 return true; 290 } 291 292 return false; 293 } 294 295 // Get the SSA value passed to the current block from the terminator operation 296 // of its predecessor. 297 static Value *getPHISourceValue(Block *current, Block *pred, 298 unsigned numArguments, unsigned index) { 299 auto &terminator = *pred->getTerminator(); 300 if (terminator.isa<LLVM::BrOp>()) { 301 return terminator.getOperand(index); 302 } 303 304 // For conditional branches, we need to check if the current block is reached 305 // through the "true" or the "false" branch and take the relevant operands. 306 auto condBranchOp = terminator.dyn_cast<LLVM::CondBrOp>(); 307 assert(condBranchOp && 308 "only branch operations can be terminators of a block that " 309 "has successors"); 310 assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && 311 "successors with arguments in LLVM conditional branches must be " 312 "different blocks"); 313 314 return condBranchOp.getSuccessor(0) == current 315 ? terminator.getSuccessorOperand(0, index) 316 : terminator.getSuccessorOperand(1, index); 317 } 318 319 void ModuleTranslation::connectPHINodes(Function &func) { 320 // Skip the first block, it cannot be branched to and its arguments correspond 321 // to the arguments of the LLVM function. 322 for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { 323 Block *bb = &*it; 324 llvm::BasicBlock *llvmBB = blockMapping.lookup(bb); 325 auto phis = llvmBB->phis(); 326 auto numArguments = bb->getNumArguments(); 327 assert(numArguments == std::distance(phis.begin(), phis.end())); 328 for (auto &numberedPhiNode : llvm::enumerate(phis)) { 329 auto &phiNode = numberedPhiNode.value(); 330 unsigned index = numberedPhiNode.index(); 331 for (auto *pred : bb->getPredecessors()) { 332 phiNode.addIncoming(valueMapping.lookup(getPHISourceValue( 333 bb, pred, numArguments, index)), 334 blockMapping.lookup(pred)); 335 } 336 } 337 } 338 } 339 340 // TODO(mlir-team): implement an iterative version 341 static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) { 342 blocks.insert(b); 343 for (Block *bb : b->getSuccessors()) { 344 if (blocks.count(bb) == 0) 345 topologicalSortImpl(blocks, bb); 346 } 347 } 348 349 // Sort function blocks topologically. 350 static llvm::SetVector<Block *> topologicalSort(Function &f) { 351 // For each blocks that has not been visited yet (i.e. that has no 352 // predecessors), add it to the list and traverse its successors in DFS 353 // preorder. 354 llvm::SetVector<Block *> blocks; 355 for (Block &b : f.getBlocks()) { 356 if (blocks.count(&b) == 0) 357 topologicalSortImpl(blocks, &b); 358 } 359 assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted"); 360 361 return blocks; 362 } 363 364 bool ModuleTranslation::convertOneFunction(Function &func) { 365 // Clear the block and value mappings, they are only relevant within one 366 // function. 367 blockMapping.clear(); 368 valueMapping.clear(); 369 llvm::Function *llvmFunc = functionMapping.lookup(&func); 370 // Add function arguments to the value remapping table. 371 // If there was noalias info then we decorate each argument accordingly. 372 unsigned int argIdx = 0; 373 for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) { 374 llvm::Argument &llvmArg = std::get<1>(kvp); 375 BlockArgument *mlirArg = std::get<0>(kvp); 376 377 if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.noalias")) { 378 // NB: Attribute already verified to be boolean, so check if we can indeed 379 // attach the attribute to this argument, based on its type. 380 auto argTy = mlirArg->getType().dyn_cast<LLVM::LLVMType>(); 381 if (!argTy.getUnderlyingType()->isPointerTy()) 382 return argTy.getContext()->emitError( 383 func.getLoc(), 384 "llvm.noalias attribute attached to LLVM non-pointer argument"); 385 if (attr.getValue()) 386 llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias); 387 } 388 valueMapping[mlirArg] = &llvmArg; 389 argIdx++; 390 } 391 392 // First, create all blocks so we can jump to them. 393 llvm::LLVMContext &llvmContext = llvmFunc->getContext(); 394 for (auto &bb : func) { 395 auto *llvmBB = llvm::BasicBlock::Create(llvmContext); 396 llvmBB->insertInto(llvmFunc); 397 blockMapping[&bb] = llvmBB; 398 } 399 400 // Then, convert blocks one by one in topological order to ensure defs are 401 // converted before uses. 402 auto blocks = topologicalSort(func); 403 for (auto indexedBB : llvm::enumerate(blocks)) { 404 auto *bb = indexedBB.value(); 405 if (convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)) 406 return true; 407 } 408 409 // Finally, after all blocks have been traversed and values mapped, connect 410 // the PHI nodes to the results of preceding blocks. 411 connectPHINodes(func); 412 return false; 413 } 414 415 bool ModuleTranslation::convertFunctions() { 416 // Declare all functions first because there may be function calls that form a 417 // call graph with cycles. 418 for (Function &function : mlirModule) { 419 Function *functionPtr = &function; 420 mlir::BoolAttr isVarArgsAttr = 421 function.getAttrOfType<BoolAttr>("std.varargs"); 422 bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue(); 423 llvm::FunctionType *functionType = 424 convertFunctionType(llvmModule->getContext(), function.getType(), 425 function.getLoc(), isVarArgs); 426 if (!functionType) 427 return true; 428 llvm::FunctionCallee llvmFuncCst = 429 llvmModule->getOrInsertFunction(function.getName(), functionType); 430 assert(isa<llvm::Function>(llvmFuncCst.getCallee())); 431 functionMapping[functionPtr] = 432 cast<llvm::Function>(llvmFuncCst.getCallee()); 433 } 434 435 // Convert functions. 436 for (Function &function : mlirModule) { 437 // Ignore external functions. 438 if (function.isExternal()) 439 continue; 440 441 if (convertOneFunction(function)) 442 return true; 443 } 444 445 return false; 446 } 447 448 std::unique_ptr<llvm::Module> ModuleTranslation::translateModule(Module &m) { 449 Dialect *dialect = m.getContext()->getRegisteredDialect("llvm"); 450 assert(dialect && "LLVM dialect must be registered"); 451 auto *llvmDialect = static_cast<LLVM::LLVMDialect *>(dialect); 452 453 auto llvmModule = llvm::CloneModule(llvmDialect->getLLVMModule()); 454 if (!llvmModule) 455 return nullptr; 456 457 llvm::LLVMContext &llvmContext = llvmModule->getContext(); 458 llvm::IRBuilder<> builder(llvmContext); 459 460 // Inject declarations for `malloc` and `free` functions that can be used in 461 // memref allocation/deallocation coming from standard ops lowering. 462 llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(), 463 builder.getInt64Ty()); 464 llvmModule->getOrInsertFunction("free", builder.getVoidTy(), 465 builder.getInt8PtrTy()); 466 467 ModuleTranslation translator(m); 468 translator.llvmModule = std::move(llvmModule); 469 if (translator.convertFunctions()) 470 return nullptr; 471 472 return std::move(translator.llvmModule); 473 } 474 475 std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module &m) { 476 return ModuleTranslation::translateModule(m); 477 } 478 479 static TranslateFromMLIRRegistration registration( 480 "mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) { 481 if (!module) 482 return true; 483 484 auto llvmModule = ModuleTranslation::translateModule(*module); 485 if (!llvmModule) 486 return true; 487 488 auto file = openOutputFile(outputFilename); 489 if (!file) 490 return true; 491 492 llvmModule->print(file->os(), nullptr); 493 file->keep(); 494 return false; 495 }); 496