1 //===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion ---------*- C++ -*-===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements a translation between the MLIR LLVM dialect and LLVM IR.
19 //
20 //===----------------------------------------------------------------------===//
21 
22 #include "mlir/IR/Attributes.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/LLVMIR/LLVMDialect.h"
25 #include "mlir/StandardOps/Ops.h"
26 #include "mlir/Support/FileUtilities.h"
27 #include "mlir/Support/LLVM.h"
28 #include "mlir/Target/LLVMIR.h"
29 #include "mlir/Translation.h"
30 
31 #include "llvm/ADT/SetVector.h"
32 #include "llvm/IR/BasicBlock.h"
33 #include "llvm/IR/Constants.h"
34 #include "llvm/IR/DerivedTypes.h"
35 #include "llvm/IR/IRBuilder.h"
36 #include "llvm/IR/LLVMContext.h"
37 #include "llvm/IR/Module.h"
38 #include "llvm/Support/ToolOutputFile.h"
39 #include "llvm/Transforms/Utils/Cloning.h"
40 
41 using namespace mlir;
42 
namespace {
// Implementation class for module translation.  Holds a reference to the module
// being translated, and the mappings between the original and the translated
// functions, basic blocks and values.  It is practically easier to hold these
// mappings in one class since the conversion of control flow operations
// needs to look up block and function mappings.
class ModuleTranslation {
public:
  // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
  // LLVM IR module.  The MLIR LLVM IR dialect holds a pointer to an
  // LLVMContext, the LLVM IR module will be created in that context.
  // Returns nullptr if the translation fails.
  static std::unique_ptr<llvm::Module> translateModule(Module &m);

private:
  // Only a reference to `module` is stored, so it must outlive this object.
  // Instances are created by translateModule.
  explicit ModuleTranslation(Module &module) : mlirModule(module) {}

  // Declare an LLVM function for every MLIR function, then convert the bodies
  // of the non-external ones.  Returns true on error.
  bool convertFunctions();
  // Convert the body of one already-declared function.  Returns true on error.
  bool convertOneFunction(Function &func);
  // Add incoming (value, block) pairs to the PHI nodes created for the
  // arguments of every block except the entry block.
  void connectPHINodes(Function &func);
  // Convert one block; unless `ignoreArguments` is set, create a PHI node per
  // block argument first.  Returns true on error.
  bool convertBlock(Block &bb, bool ignoreArguments);
  // Convert one operation, emitting LLVM IR through `builder`.  Returns true
  // on error.
  bool convertOperation(Operation &op, llvm::IRBuilder<> &builder);

  // Map each MLIR value in `values` to its translated LLVM value, in order;
  // values without a mapping yield nullptr.
  template <typename Range>
  SmallVector<llvm::Value *, 8> lookupValues(Range &&values);

  // Build an LLVM constant of `llvmType` from `attr`.  Errors are reported at
  // `loc` and signalled by a nullptr result.
  llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
                                  Location loc);

  // Original and translated module.
  Module &mlirModule;
  std::unique_ptr<llvm::Module> llvmModule;

  // Mappings between original and translated values, used for lookups.
  // functionMapping spans the whole module; valueMapping and blockMapping are
  // cleared at the start of each function conversion.
  llvm::DenseMap<Function *, llvm::Function *> functionMapping;
  llvm::DenseMap<Value *, llvm::Value *> valueMapping;
  llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping;
};
} // end anonymous namespace
81 
82 // Convert an MLIR function type to LLVM IR.  Arguments of the function must of
83 // MLIR LLVM IR dialect types.  Use `loc` as a location when reporting errors.
84 // Return nullptr on errors.
85 static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
86                                                FunctionType type, Location loc,
87                                                bool isVarArgs) {
88   assert(type && "expected non-null type");
89 
90   auto context = type.getContext();
91   if (type.getNumResults() > 1)
92     return context->emitError(loc,
93                               "LLVM functions can only have 0 or 1 result"),
94            nullptr;
95 
96   SmallVector<llvm::Type *, 8> argTypes;
97   argTypes.reserve(type.getNumInputs());
98   for (auto t : type.getInputs()) {
99     auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>();
100     if (!wrappedLLVMType)
101       return context->emitError(loc, "non-LLVM function argument type"),
102              nullptr;
103     argTypes.push_back(wrappedLLVMType.getUnderlyingType());
104   }
105 
106   if (type.getNumResults() == 0)
107     return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes,
108                                    isVarArgs);
109 
110   auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>();
111   if (!wrappedResultType)
112     return context->emitError(loc, "non-LLVM function result"), nullptr;
113 
114   return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(),
115                                  argTypes, isVarArgs);
116 }
117 
118 // Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
119 // This currently supports integer, floating point, splat and dense element
120 // attributes and combinations thereof.  In case of error, report it to `loc`
121 // and return nullptr.
122 llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
123                                                    Attribute attr,
124                                                    Location loc) {
125   if (auto intAttr = attr.dyn_cast<IntegerAttr>())
126     return llvm::ConstantInt::get(llvmType, intAttr.getValue());
127   if (auto floatAttr = attr.dyn_cast<FloatAttr>())
128     return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
129   if (auto funcAttr = attr.dyn_cast<FunctionAttr>())
130     return functionMapping.lookup(funcAttr.getValue());
131   if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
132     auto *vectorType = cast<llvm::VectorType>(llvmType);
133     auto *child = getLLVMConstant(vectorType->getElementType(),
134                                   splatAttr.getValue(), loc);
135     return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
136   }
137   if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
138     auto *vectorType = cast<llvm::VectorType>(llvmType);
139     SmallVector<llvm::Constant *, 8> constants;
140     uint64_t numElements = vectorType->getNumElements();
141     constants.reserve(numElements);
142     SmallVector<Attribute, 8> nested;
143     denseAttr.getValues(nested);
144     for (auto n : nested) {
145       constants.push_back(
146           getLLVMConstant(vectorType->getElementType(), n, loc));
147       if (!constants.back())
148         return nullptr;
149     }
150     return llvm::ConstantVector::get(constants);
151   }
152   mlirModule.getContext()->emitError(loc, "unsupported constant value");
153   return nullptr;
154 }
155 
156 // Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
157 static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) {
158   switch (p) {
159   case CmpIPredicate::EQ:
160     return llvm::CmpInst::Predicate::ICMP_EQ;
161   case CmpIPredicate::NE:
162     return llvm::CmpInst::Predicate::ICMP_NE;
163   case CmpIPredicate::SLT:
164     return llvm::CmpInst::Predicate::ICMP_SLT;
165   case CmpIPredicate::SLE:
166     return llvm::CmpInst::Predicate::ICMP_SLE;
167   case CmpIPredicate::SGT:
168     return llvm::CmpInst::Predicate::ICMP_SGT;
169   case CmpIPredicate::SGE:
170     return llvm::CmpInst::Predicate::ICMP_SGE;
171   case CmpIPredicate::ULT:
172     return llvm::CmpInst::Predicate::ICMP_ULT;
173   case CmpIPredicate::ULE:
174     return llvm::CmpInst::Predicate::ICMP_ULE;
175   case CmpIPredicate::UGT:
176     return llvm::CmpInst::Predicate::ICMP_UGT;
177   case CmpIPredicate::UGE:
178     return llvm::CmpInst::Predicate::ICMP_UGE;
179   default:
180     llvm_unreachable("incorrect comparison predicate");
181   }
182 }
183 
184 // A helper to look up remapped operands in the value remapping table.
185 template <typename Range>
186 SmallVector<llvm::Value *, 8> ModuleTranslation::lookupValues(Range &&values) {
187   SmallVector<llvm::Value *, 8> remapped;
188   remapped.reserve(llvm::size(values));
189   for (Value *v : values) {
190     remapped.push_back(valueMapping.lookup(v));
191   }
192   return remapped;
193 }
194 
195 // Given a single MLIR operation, create the corresponding LLVM IR operation
196 // using the `builder`.  LLVM IR Builder does not have a generic interface so
197 // this has to be a long chain of `if`s calling different functions with a
198 // different number of arguments.
199 bool ModuleTranslation::convertOperation(Operation &opInst,
200                                          llvm::IRBuilder<> &builder) {
201   auto extractPosition = [](ArrayAttr attr) {
202     SmallVector<unsigned, 4> position;
203     position.reserve(attr.size());
204     for (Attribute v : attr)
205       position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
206     return position;
207   };
208 
209 #include "mlir/LLVMIR/LLVMConversions.inc"
210 
211   // Emit function calls.  If the "callee" attribute is present, this is a
212   // direct function call and we also need to look up the remapped function
213   // itself.  Otherwise, this is an indirect call and the callee is the first
214   // operand, look it up as a normal value.  Return the llvm::Value representing
215   // the function result, which may be of llvm::VoidTy type.
216   auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
217     auto operands = lookupValues(op.getOperands());
218     ArrayRef<llvm::Value *> operandsRef(operands);
219     if (auto attr = op.getAttrOfType<FunctionAttr>("callee")) {
220       return builder.CreateCall(functionMapping.lookup(attr.getValue()),
221                                 operandsRef);
222     } else {
223       return builder.CreateCall(operandsRef.front(), operandsRef.drop_front());
224     }
225   };
226 
227   // Emit calls.  If the called function has a result, remap the corresponding
228   // value.  Note that LLVM IR dialect CallOp has either 0 or 1 result.
229   if (opInst.isa<LLVM::CallOp>()) {
230     llvm::Value *result = convertCall(opInst);
231     if (opInst.getNumResults() != 0) {
232       valueMapping[opInst.getResult(0)] = result;
233       return false;
234     }
235     // Check that LLVM call returns void for 0-result functions.
236     return !result->getType()->isVoidTy();
237   }
238 
239   // Emit branches.  We need to look up the remapped blocks and ignore the block
240   // arguments that were transformed into PHI nodes.
241   if (auto brOp = opInst.dyn_cast<LLVM::BrOp>()) {
242     builder.CreateBr(blockMapping[brOp.getSuccessor(0)]);
243     return false;
244   }
245   if (auto condbrOp = opInst.dyn_cast<LLVM::CondBrOp>()) {
246     builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
247                          blockMapping[condbrOp.getSuccessor(0)],
248                          blockMapping[condbrOp.getSuccessor(1)]);
249     return false;
250   }
251 
252   opInst.emitError("unsupported or non-LLVM operation: " +
253                    opInst.getName().getStringRef());
254   return true;
255 }
256 
257 // Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
258 // to define values corresponding to the MLIR block arguments.  These nodes
259 // are not connected to the source basic blocks, which may not exist yet.
260 bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
261   llvm::IRBuilder<> builder(blockMapping[&bb]);
262 
263   // Before traversing operations, make block arguments available through
264   // value remapping and PHI nodes, but do not add incoming edges for the PHI
265   // nodes just yet: those values may be defined by this or following blocks.
266   // This step is omitted if "ignoreArguments" is set.  The arguments of the
267   // first block have been already made available through the remapping of
268   // LLVM function arguments.
269   if (!ignoreArguments) {
270     auto predecessors = bb.getPredecessors();
271     unsigned numPredecessors =
272         std::distance(predecessors.begin(), predecessors.end());
273     for (auto *arg : bb.getArguments()) {
274       auto wrappedType = arg->getType().dyn_cast<LLVM::LLVMType>();
275       if (!wrappedType) {
276         arg->getType().getContext()->emitError(
277             bb.front().getLoc(), "block argument does not have an LLVM type");
278         return true;
279       }
280       llvm::Type *type = wrappedType.getUnderlyingType();
281       llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
282       valueMapping[arg] = phi;
283     }
284   }
285 
286   // Traverse operations.
287   for (auto &op : bb) {
288     if (convertOperation(op, builder))
289       return true;
290   }
291 
292   return false;
293 }
294 
295 // Get the SSA value passed to the current block from the terminator operation
296 // of its predecessor.
297 static Value *getPHISourceValue(Block *current, Block *pred,
298                                 unsigned numArguments, unsigned index) {
299   auto &terminator = *pred->getTerminator();
300   if (terminator.isa<LLVM::BrOp>()) {
301     return terminator.getOperand(index);
302   }
303 
304   // For conditional branches, we need to check if the current block is reached
305   // through the "true" or the "false" branch and take the relevant operands.
306   auto condBranchOp = terminator.dyn_cast<LLVM::CondBrOp>();
307   assert(condBranchOp &&
308          "only branch operations can be terminators of a block that "
309          "has successors");
310   assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
311          "successors with arguments in LLVM conditional branches must be "
312          "different blocks");
313 
314   return condBranchOp.getSuccessor(0) == current
315              ? terminator.getSuccessorOperand(0, index)
316              : terminator.getSuccessorOperand(1, index);
317 }
318 
319 void ModuleTranslation::connectPHINodes(Function &func) {
320   // Skip the first block, it cannot be branched to and its arguments correspond
321   // to the arguments of the LLVM function.
322   for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
323     Block *bb = &*it;
324     llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
325     auto phis = llvmBB->phis();
326     auto numArguments = bb->getNumArguments();
327     assert(numArguments == std::distance(phis.begin(), phis.end()));
328     for (auto &numberedPhiNode : llvm::enumerate(phis)) {
329       auto &phiNode = numberedPhiNode.value();
330       unsigned index = numberedPhiNode.index();
331       for (auto *pred : bb->getPredecessors()) {
332         phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
333                                 bb, pred, numArguments, index)),
334                             blockMapping.lookup(pred));
335       }
336     }
337   }
338 }
339 
340 // TODO(mlir-team): implement an iterative version
341 static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
342   blocks.insert(b);
343   for (Block *bb : b->getSuccessors()) {
344     if (blocks.count(bb) == 0)
345       topologicalSortImpl(blocks, bb);
346   }
347 }
348 
349 // Sort function blocks topologically.
350 static llvm::SetVector<Block *> topologicalSort(Function &f) {
351   // For each blocks that has not been visited yet (i.e. that has no
352   // predecessors), add it to the list and traverse its successors in DFS
353   // preorder.
354   llvm::SetVector<Block *> blocks;
355   for (Block &b : f.getBlocks()) {
356     if (blocks.count(&b) == 0)
357       topologicalSortImpl(blocks, &b);
358   }
359   assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
360 
361   return blocks;
362 }
363 
364 bool ModuleTranslation::convertOneFunction(Function &func) {
365   // Clear the block and value mappings, they are only relevant within one
366   // function.
367   blockMapping.clear();
368   valueMapping.clear();
369   llvm::Function *llvmFunc = functionMapping.lookup(&func);
370   // Add function arguments to the value remapping table.
371   // If there was noalias info then we decorate each argument accordingly.
372   unsigned int argIdx = 0;
373   for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) {
374     llvm::Argument &llvmArg = std::get<1>(kvp);
375     BlockArgument *mlirArg = std::get<0>(kvp);
376 
377     if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.noalias")) {
378       // NB: Attribute already verified to be boolean, so check if we can indeed
379       // attach the attribute to this argument, based on its type.
380       auto argTy = mlirArg->getType().dyn_cast<LLVM::LLVMType>();
381       if (!argTy.getUnderlyingType()->isPointerTy())
382         return argTy.getContext()->emitError(
383             func.getLoc(),
384             "llvm.noalias attribute attached to LLVM non-pointer argument");
385       if (attr.getValue())
386         llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
387     }
388     valueMapping[mlirArg] = &llvmArg;
389     argIdx++;
390   }
391 
392   // First, create all blocks so we can jump to them.
393   llvm::LLVMContext &llvmContext = llvmFunc->getContext();
394   for (auto &bb : func) {
395     auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
396     llvmBB->insertInto(llvmFunc);
397     blockMapping[&bb] = llvmBB;
398   }
399 
400   // Then, convert blocks one by one in topological order to ensure defs are
401   // converted before uses.
402   auto blocks = topologicalSort(func);
403   for (auto indexedBB : llvm::enumerate(blocks)) {
404     auto *bb = indexedBB.value();
405     if (convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))
406       return true;
407   }
408 
409   // Finally, after all blocks have been traversed and values mapped, connect
410   // the PHI nodes to the results of preceding blocks.
411   connectPHINodes(func);
412   return false;
413 }
414 
415 bool ModuleTranslation::convertFunctions() {
416   // Declare all functions first because there may be function calls that form a
417   // call graph with cycles.
418   for (Function &function : mlirModule) {
419     Function *functionPtr = &function;
420     mlir::BoolAttr isVarArgsAttr =
421         function.getAttrOfType<BoolAttr>("std.varargs");
422     bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
423     llvm::FunctionType *functionType =
424         convertFunctionType(llvmModule->getContext(), function.getType(),
425                             function.getLoc(), isVarArgs);
426     if (!functionType)
427       return true;
428     llvm::FunctionCallee llvmFuncCst =
429         llvmModule->getOrInsertFunction(function.getName(), functionType);
430     assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
431     functionMapping[functionPtr] =
432         cast<llvm::Function>(llvmFuncCst.getCallee());
433   }
434 
435   // Convert functions.
436   for (Function &function : mlirModule) {
437     // Ignore external functions.
438     if (function.isExternal())
439       continue;
440 
441     if (convertOneFunction(function))
442       return true;
443   }
444 
445   return false;
446 }
447 
448 std::unique_ptr<llvm::Module> ModuleTranslation::translateModule(Module &m) {
449   Dialect *dialect = m.getContext()->getRegisteredDialect("llvm");
450   assert(dialect && "LLVM dialect must be registered");
451   auto *llvmDialect = static_cast<LLVM::LLVMDialect *>(dialect);
452 
453   auto llvmModule = llvm::CloneModule(llvmDialect->getLLVMModule());
454   if (!llvmModule)
455     return nullptr;
456 
457   llvm::LLVMContext &llvmContext = llvmModule->getContext();
458   llvm::IRBuilder<> builder(llvmContext);
459 
460   // Inject declarations for `malloc` and `free` functions that can be used in
461   // memref allocation/deallocation coming from standard ops lowering.
462   llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
463                                   builder.getInt64Ty());
464   llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
465                                   builder.getInt8PtrTy());
466 
467   ModuleTranslation translator(m);
468   translator.llvmModule = std::move(llvmModule);
469   if (translator.convertFunctions())
470     return nullptr;
471 
472   return std::move(translator.llvmModule);
473 }
474 
475 std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module &m) {
476   return ModuleTranslation::translateModule(m);
477 }
478 
479 static TranslateFromMLIRRegistration registration(
480     "mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) {
481       if (!module)
482         return true;
483 
484       auto llvmModule = ModuleTranslation::translateModule(*module);
485       if (!llvmModule)
486         return true;
487 
488       auto file = openOutputFile(outputFilename);
489       if (!file)
490         return true;
491 
492       llvmModule->print(file->os(), nullptr);
493       file->keep();
494       return false;
495     });
496