1 //===- ModuleTranslation.cpp - MLIR to LLVM conversion --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the translation between an MLIR LLVM dialect module and
10 // the corresponding LLVMIR module. It only handles core LLVM IR operations.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
15 
16 #include "DebugTranslation.h"
17 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/RegionGraphTraits.h"
23 #include "mlir/Support/LLVM.h"
24 #include "mlir/Target/LLVMIR/TypeTranslation.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
30 #include "llvm/IR/BasicBlock.h"
31 #include "llvm/IR/CFG.h"
32 #include "llvm/IR/Constants.h"
33 #include "llvm/IR/DerivedTypes.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/InlineAsm.h"
36 #include "llvm/IR/LLVMContext.h"
37 #include "llvm/IR/MDBuilder.h"
38 #include "llvm/IR/Module.h"
39 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
40 #include "llvm/Transforms/Utils/Cloning.h"
41 
42 using namespace mlir;
43 using namespace mlir::LLVM;
44 using namespace mlir::LLVM::detail;
45 
46 #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
47 
48 /// Builds a constant of a sequential LLVM type `type`, potentially containing
49 /// other sequential types recursively, from the individual constant values
50 /// provided in `constants`. `shape` contains the number of elements in nested
51 /// sequential types. Reports errors at `loc` and returns nullptr on error.
52 static llvm::Constant *
53 buildSequentialConstant(ArrayRef<llvm::Constant *> &constants,
54                         ArrayRef<int64_t> shape, llvm::Type *type,
55                         Location loc) {
56   if (shape.empty()) {
57     llvm::Constant *result = constants.front();
58     constants = constants.drop_front();
59     return result;
60   }
61 
62   llvm::Type *elementType;
63   if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
64     elementType = arrayTy->getElementType();
65   } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
66     elementType = vectorTy->getElementType();
67   } else {
68     emitError(loc) << "expected sequential LLVM types wrapping a scalar";
69     return nullptr;
70   }
71 
72   SmallVector<llvm::Constant *, 8> nested;
73   nested.reserve(shape.front());
74   for (int64_t i = 0; i < shape.front(); ++i) {
75     nested.push_back(buildSequentialConstant(constants, shape.drop_front(),
76                                              elementType, loc));
77     if (!nested.back())
78       return nullptr;
79   }
80 
81   if (shape.size() == 1 && type->isVectorTy())
82     return llvm::ConstantVector::get(nested);
83   return llvm::ConstantArray::get(
84       llvm::ArrayType::get(elementType, shape.front()), nested);
85 }
86 
87 /// Returns the first non-sequential type nested in sequential types.
88 static llvm::Type *getInnermostElementType(llvm::Type *type) {
89   do {
90     if (auto *arrayTy = dyn_cast<llvm::ArrayType>(type)) {
91       type = arrayTy->getElementType();
92     } else if (auto *vectorTy = dyn_cast<llvm::VectorType>(type)) {
93       type = vectorTy->getElementType();
94     } else {
95       return type;
96     }
97   } while (true);
98 }
99 
100 /// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
101 /// This currently supports integer, floating point, splat and dense element
102 /// attributes and combinations thereof.  In case of error, report it to `loc`
103 /// and return nullptr.
104 llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
105                                                    Attribute attr,
106                                                    Location loc) {
107   if (!attr)
108     return llvm::UndefValue::get(llvmType);
109   if (llvmType->isStructTy()) {
110     emitError(loc, "struct types are not supported in constants");
111     return nullptr;
112   }
113   // For integer types, we allow a mismatch in sizes as the index type in
114   // MLIR might have a different size than the index type in the LLVM module.
115   if (auto intAttr = attr.dyn_cast<IntegerAttr>())
116     return llvm::ConstantInt::get(
117         llvmType,
118         intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
119   if (auto floatAttr = attr.dyn_cast<FloatAttr>())
120     return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
121   if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
122     return llvm::ConstantExpr::getBitCast(lookupFunction(funcAttr.getValue()),
123                                           llvmType);
124   if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
125     llvm::Type *elementType;
126     uint64_t numElements;
127     if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
128       elementType = arrayTy->getElementType();
129       numElements = arrayTy->getNumElements();
130     } else {
131       auto *vectorTy = cast<llvm::FixedVectorType>(llvmType);
132       elementType = vectorTy->getElementType();
133       numElements = vectorTy->getNumElements();
134     }
135     // Splat value is a scalar. Extract it only if the element type is not
136     // another sequence type. The recursion terminates because each step removes
137     // one outer sequential type.
138     bool elementTypeSequential =
139         isa<llvm::ArrayType, llvm::VectorType>(elementType);
140     llvm::Constant *child = getLLVMConstant(
141         elementType,
142         elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc);
143     if (!child)
144       return nullptr;
145     if (llvmType->isVectorTy())
146       return llvm::ConstantVector::getSplat(
147           llvm::ElementCount::get(numElements, /*Scalable=*/false), child);
148     if (llvmType->isArrayTy()) {
149       auto *arrayType = llvm::ArrayType::get(elementType, numElements);
150       SmallVector<llvm::Constant *, 8> constants(numElements, child);
151       return llvm::ConstantArray::get(arrayType, constants);
152     }
153   }
154 
155   if (auto elementsAttr = attr.dyn_cast<ElementsAttr>()) {
156     assert(elementsAttr.getType().hasStaticShape());
157     assert(elementsAttr.getNumElements() != 0 &&
158            "unexpected empty elements attribute");
159     assert(!elementsAttr.getType().getShape().empty() &&
160            "unexpected empty elements attribute shape");
161 
162     SmallVector<llvm::Constant *, 8> constants;
163     constants.reserve(elementsAttr.getNumElements());
164     llvm::Type *innermostType = getInnermostElementType(llvmType);
165     for (auto n : elementsAttr.getValues<Attribute>()) {
166       constants.push_back(getLLVMConstant(innermostType, n, loc));
167       if (!constants.back())
168         return nullptr;
169     }
170     ArrayRef<llvm::Constant *> constantsRef = constants;
171     llvm::Constant *result = buildSequentialConstant(
172         constantsRef, elementsAttr.getType().getShape(), llvmType, loc);
173     assert(constantsRef.empty() && "did not consume all elemental constants");
174     return result;
175   }
176 
177   if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
178     return llvm::ConstantDataArray::get(
179         llvmModule->getContext(), ArrayRef<char>{stringAttr.getValue().data(),
180                                                  stringAttr.getValue().size()});
181   }
182   emitError(loc, "unsupported constant value");
183   return nullptr;
184 }
185 
186 /// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
187 static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) {
188   switch (p) {
189   case LLVM::ICmpPredicate::eq:
190     return llvm::CmpInst::Predicate::ICMP_EQ;
191   case LLVM::ICmpPredicate::ne:
192     return llvm::CmpInst::Predicate::ICMP_NE;
193   case LLVM::ICmpPredicate::slt:
194     return llvm::CmpInst::Predicate::ICMP_SLT;
195   case LLVM::ICmpPredicate::sle:
196     return llvm::CmpInst::Predicate::ICMP_SLE;
197   case LLVM::ICmpPredicate::sgt:
198     return llvm::CmpInst::Predicate::ICMP_SGT;
199   case LLVM::ICmpPredicate::sge:
200     return llvm::CmpInst::Predicate::ICMP_SGE;
201   case LLVM::ICmpPredicate::ult:
202     return llvm::CmpInst::Predicate::ICMP_ULT;
203   case LLVM::ICmpPredicate::ule:
204     return llvm::CmpInst::Predicate::ICMP_ULE;
205   case LLVM::ICmpPredicate::ugt:
206     return llvm::CmpInst::Predicate::ICMP_UGT;
207   case LLVM::ICmpPredicate::uge:
208     return llvm::CmpInst::Predicate::ICMP_UGE;
209   }
210   llvm_unreachable("incorrect comparison predicate");
211 }
212 
213 static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
214   switch (p) {
215   case LLVM::FCmpPredicate::_false:
216     return llvm::CmpInst::Predicate::FCMP_FALSE;
217   case LLVM::FCmpPredicate::oeq:
218     return llvm::CmpInst::Predicate::FCMP_OEQ;
219   case LLVM::FCmpPredicate::ogt:
220     return llvm::CmpInst::Predicate::FCMP_OGT;
221   case LLVM::FCmpPredicate::oge:
222     return llvm::CmpInst::Predicate::FCMP_OGE;
223   case LLVM::FCmpPredicate::olt:
224     return llvm::CmpInst::Predicate::FCMP_OLT;
225   case LLVM::FCmpPredicate::ole:
226     return llvm::CmpInst::Predicate::FCMP_OLE;
227   case LLVM::FCmpPredicate::one:
228     return llvm::CmpInst::Predicate::FCMP_ONE;
229   case LLVM::FCmpPredicate::ord:
230     return llvm::CmpInst::Predicate::FCMP_ORD;
231   case LLVM::FCmpPredicate::ueq:
232     return llvm::CmpInst::Predicate::FCMP_UEQ;
233   case LLVM::FCmpPredicate::ugt:
234     return llvm::CmpInst::Predicate::FCMP_UGT;
235   case LLVM::FCmpPredicate::uge:
236     return llvm::CmpInst::Predicate::FCMP_UGE;
237   case LLVM::FCmpPredicate::ult:
238     return llvm::CmpInst::Predicate::FCMP_ULT;
239   case LLVM::FCmpPredicate::ule:
240     return llvm::CmpInst::Predicate::FCMP_ULE;
241   case LLVM::FCmpPredicate::une:
242     return llvm::CmpInst::Predicate::FCMP_UNE;
243   case LLVM::FCmpPredicate::uno:
244     return llvm::CmpInst::Predicate::FCMP_UNO;
245   case LLVM::FCmpPredicate::_true:
246     return llvm::CmpInst::Predicate::FCMP_TRUE;
247   }
248   llvm_unreachable("incorrect comparison predicate");
249 }
250 
251 static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
252   switch (op) {
253   case LLVM::AtomicBinOp::xchg:
254     return llvm::AtomicRMWInst::BinOp::Xchg;
255   case LLVM::AtomicBinOp::add:
256     return llvm::AtomicRMWInst::BinOp::Add;
257   case LLVM::AtomicBinOp::sub:
258     return llvm::AtomicRMWInst::BinOp::Sub;
259   case LLVM::AtomicBinOp::_and:
260     return llvm::AtomicRMWInst::BinOp::And;
261   case LLVM::AtomicBinOp::nand:
262     return llvm::AtomicRMWInst::BinOp::Nand;
263   case LLVM::AtomicBinOp::_or:
264     return llvm::AtomicRMWInst::BinOp::Or;
265   case LLVM::AtomicBinOp::_xor:
266     return llvm::AtomicRMWInst::BinOp::Xor;
267   case LLVM::AtomicBinOp::max:
268     return llvm::AtomicRMWInst::BinOp::Max;
269   case LLVM::AtomicBinOp::min:
270     return llvm::AtomicRMWInst::BinOp::Min;
271   case LLVM::AtomicBinOp::umax:
272     return llvm::AtomicRMWInst::BinOp::UMax;
273   case LLVM::AtomicBinOp::umin:
274     return llvm::AtomicRMWInst::BinOp::UMin;
275   case LLVM::AtomicBinOp::fadd:
276     return llvm::AtomicRMWInst::BinOp::FAdd;
277   case LLVM::AtomicBinOp::fsub:
278     return llvm::AtomicRMWInst::BinOp::FSub;
279   }
280   llvm_unreachable("incorrect atomic binary operator");
281 }
282 
283 static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
284   switch (ordering) {
285   case LLVM::AtomicOrdering::not_atomic:
286     return llvm::AtomicOrdering::NotAtomic;
287   case LLVM::AtomicOrdering::unordered:
288     return llvm::AtomicOrdering::Unordered;
289   case LLVM::AtomicOrdering::monotonic:
290     return llvm::AtomicOrdering::Monotonic;
291   case LLVM::AtomicOrdering::acquire:
292     return llvm::AtomicOrdering::Acquire;
293   case LLVM::AtomicOrdering::release:
294     return llvm::AtomicOrdering::Release;
295   case LLVM::AtomicOrdering::acq_rel:
296     return llvm::AtomicOrdering::AcquireRelease;
297   case LLVM::AtomicOrdering::seq_cst:
298     return llvm::AtomicOrdering::SequentiallyConsistent;
299   }
300   llvm_unreachable("incorrect atomic ordering");
301 }
302 
/// Prepares translation of the MLIR operation `module` into `llvmModule`,
/// taking ownership of the LLVM module. The debug-info translator and the type
/// translator are bound to that LLVM module, and the "omp" dialect is looked
/// up among the dialects already loaded in the MLIR context (presumably null
/// when it is not loaded — TODO confirm getLoadedDialect semantics).
ModuleTranslation::ModuleTranslation(Operation *module,
                                     std::unique_ptr<llvm::Module> llvmModule)
    : mlirModule(module), llvmModule(std::move(llvmModule)),
      // `this->llvmModule` (the member) is used here because the parameter
      // has already been moved from; this relies on the member being
      // declared, and thus initialized, before `debugTranslation`.
      debugTranslation(
          std::make_unique<DebugTranslation>(module, *this->llvmModule)),
      ompDialect(module->getContext()->getLoadedDialect("omp")),
      typeTranslator(this->llvmModule->getContext()) {
  assert(satisfiesLLVMModule(mlirModule) &&
         "mlirModule should honor LLVM's module semantics.");
}
313 ModuleTranslation::~ModuleTranslation() {
314   if (ompBuilder)
315     ompBuilder->finalize();
316 }
317 
318 /// Get the SSA value passed to the current block from the terminator operation
319 /// of its predecessor.
320 static Value getPHISourceValue(Block *current, Block *pred,
321                                unsigned numArguments, unsigned index) {
322   Operation &terminator = *pred->getTerminator();
323   if (isa<LLVM::BrOp>(terminator))
324     return terminator.getOperand(index);
325 
326   SuccessorRange successors = terminator.getSuccessors();
327   assert(std::adjacent_find(successors.begin(), successors.end()) ==
328              successors.end() &&
329          "successors with arguments in LLVM branches must be different blocks");
330   (void)successors;
331 
332   // For instructions that branch based on a condition value, we need to take
333   // the operands for the branch that was taken.
334   if (auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator)) {
335     // For conditional branches, we take the operands from either the "true" or
336     // the "false" branch.
337     return condBranchOp.getSuccessor(0) == current
338                ? condBranchOp.trueDestOperands()[index]
339                : condBranchOp.falseDestOperands()[index];
340   }
341 
342   if (auto switchOp = dyn_cast<LLVM::SwitchOp>(terminator)) {
343     // For switches, we take the operands from either the default case, or from
344     // the case branch that was taken.
345     if (switchOp.defaultDestination() == current)
346       return switchOp.defaultOperands()[index];
347     for (auto i : llvm::enumerate(switchOp.caseDestinations()))
348       if (i.value() == current)
349         return switchOp.getCaseOperands(i.index())[index];
350   }
351 
352   llvm_unreachable("only branch or switch operations can be terminators of a "
353                    "block that has successors");
354 }
355 
356 /// Connect the PHI nodes to the results of preceding blocks.
357 template <typename T>
358 static void connectPHINodes(T &func, const ModuleTranslation &state) {
359   // Skip the first block, it cannot be branched to and its arguments correspond
360   // to the arguments of the LLVM function.
361   for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
362     Block *bb = &*it;
363     llvm::BasicBlock *llvmBB = state.lookupBlock(bb);
364     auto phis = llvmBB->phis();
365     auto numArguments = bb->getNumArguments();
366     assert(numArguments == std::distance(phis.begin(), phis.end()));
367     for (auto &numberedPhiNode : llvm::enumerate(phis)) {
368       auto &phiNode = numberedPhiNode.value();
369       unsigned index = numberedPhiNode.index();
370       for (auto *pred : bb->getPredecessors()) {
371         // Find the LLVM IR block that contains the converted terminator
372         // instruction and use it in the PHI node. Note that this block is not
373         // necessarily the same as state.lookupBlock(pred), some operations
374         // (in particular, OpenMP operations using OpenMPIRBuilder) may have
375         // split the blocks.
376         llvm::Instruction *terminator =
377             state.lookupBranch(pred->getTerminator());
378         assert(terminator && "missing the mapping for a terminator");
379         phiNode.addIncoming(
380             state.lookupValue(getPHISourceValue(bb, pred, numArguments, index)),
381             terminator->getParent());
382       }
383     }
384   }
385 }
386 
387 /// Sort function blocks topologically.
388 template <typename T>
389 static llvm::SetVector<Block *> topologicalSort(T &f) {
390   // For each block that has not been visited yet (i.e. that has no
391   // predecessors), add it to the list as well as its successors.
392   llvm::SetVector<Block *> blocks;
393   for (Block &b : f) {
394     if (blocks.count(&b) == 0) {
395       llvm::ReversePostOrderTraversal<Block *> traversal(&b);
396       blocks.insert(traversal.begin(), traversal.end());
397     }
398   }
399   assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
400 
401   return blocks;
402 }
403 
404 /// Convert the OpenMP parallel Operation to LLVM IR.
405 LogicalResult
406 ModuleTranslation::convertOmpParallel(Operation &opInst,
407                                       llvm::IRBuilder<> &builder) {
408   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
409   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
410   // relying on captured variables.
411   LogicalResult bodyGenStatus = success();
412 
413   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
414                        llvm::BasicBlock &continuationBlock) {
415     // ParallelOp has only one region associated with it.
416     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
417     convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
418                         continuationBlock, builder, bodyGenStatus);
419   };
420 
421   // TODO: Perform appropriate actions according to the data-sharing
422   // attribute (shared, private, firstprivate, ...) of variables.
423   // Currently defaults to shared.
424   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
425                     llvm::Value &, llvm::Value &vPtr,
426                     llvm::Value *&replacementValue) -> InsertPointTy {
427     replacementValue = &vPtr;
428 
429     return codeGenIP;
430   };
431 
432   // TODO: Perform finalization actions for variables. This has to be
433   // called for variables which have destructors/finalizers.
434   auto finiCB = [&](InsertPointTy codeGenIP) {};
435 
436   llvm::Value *ifCond = nullptr;
437   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
438     ifCond = lookupValue(ifExprVar);
439   llvm::Value *numThreads = nullptr;
440   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
441     numThreads = lookupValue(numThreadsVar);
442   llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
443   if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
444     pbKind = llvm::omp::getProcBindKind(bind.getValue());
445   // TODO: Is the Parallel construct cancellable?
446   bool isCancellable = false;
447   // TODO: Determine the actual alloca insertion point, e.g., the function
448   // entry or the alloca insertion point as provided by the body callback
449   // above.
450   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
451   if (failed(bodyGenStatus))
452     return failure();
453   builder.restoreIP(
454       ompBuilder->createParallel(builder, allocaIP, bodyGenCB, privCB, finiCB,
455                                  ifCond, numThreads, pbKind, isCancellable));
456   return success();
457 }
458 
/// Translates the blocks of `region` (the body of an OpenMP operation) into
/// newly created LLVM IR blocks named `blockName`, appended to the function
/// that `builder` is currently inserting into.
///
/// The terminator of `sourceBlock` is expected (asserted) to have exactly one
/// successor, `continuationBlock`; it is retargeted to the translated entry
/// block of the region. `omp.terminator` / `omp.yield` terminators become
/// branches to `continuationBlock`. If converting any block fails,
/// `bodyGenStatus` is set to failure and the function returns immediately
/// without connecting PHI nodes; otherwise `bodyGenStatus` is not modified.
void ModuleTranslation::convertOmpOpRegions(
    Region &region, StringRef blockName,
    llvm::BasicBlock &sourceBlock, llvm::BasicBlock &continuationBlock,
    llvm::IRBuilder<> &builder, LogicalResult &bodyGenStatus) {
  llvm::LLVMContext &llvmContext = builder.getContext();
  // Create and map an LLVM block for every MLIR block before converting any
  // of them, presumably so branches resolve regardless of conversion order.
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent());
    mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  llvm::SetVector<Block *> blocks = topologicalSort(region);
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    // Restores the builder's insertion point when this iteration ends.
    llvm::IRBuilder<>::InsertPointGuard guard(builder);
    if (failed(convertBlock(*bb, bb->isEntryBlock(), builder))) {
      bodyGenStatus = failure();
      return;
    }

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying on inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    // NOTE(review): assumes convertBlock leaves `builder` positioned at the end
    // of the converted block, so the branch lands there — confirm.
    if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
      builder.CreateBr(&continuationBlock);
  }
  // Finally, after all blocks have been traversed and values mapped,
  // connect the PHI nodes to the results of preceding blocks.
  connectPHINodes(region, *this);
}
507 
508 LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
509                                                   llvm::IRBuilder<> &builder) {
510   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
511   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
512   // relying on captured variables.
513   LogicalResult bodyGenStatus = success();
514 
515   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
516                        llvm::BasicBlock &continuationBlock) {
517     // MasterOp has only one region associated with it.
518     auto &region = cast<omp::MasterOp>(opInst).getRegion();
519     convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
520                         continuationBlock, builder, bodyGenStatus);
521   };
522 
523   // TODO: Perform finalization actions for variables. This has to be
524   // called for variables which have destructors/finalizers.
525   auto finiCB = [&](InsertPointTy codeGenIP) {};
526 
527   builder.restoreIP(ompBuilder->createMaster(builder, bodyGenCB, finiCB));
528   return success();
529 }
530 
531 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
532 LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst,
533                                                   llvm::IRBuilder<> &builder) {
534   auto loop = cast<omp::WsLoopOp>(opInst);
535   // TODO: this should be in the op verifier instead.
536   if (loop.lowerBound().empty())
537     return failure();
538 
539   if (loop.getNumLoops() != 1)
540     return opInst.emitOpError("collapsed loops not yet supported");
541 
542   if (loop.schedule_val().hasValue() &&
543       omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) !=
544           omp::ClauseScheduleKind::Static)
545     return opInst.emitOpError(
546         "only static (default) loop schedule is currently supported");
547 
548   // Find the loop configuration.
549   llvm::Value *lowerBound = lookupValue(loop.lowerBound()[0]);
550   llvm::Value *upperBound = lookupValue(loop.upperBound()[0]);
551   llvm::Value *step = lookupValue(loop.step()[0]);
552   llvm::Type *ivType = step->getType();
553   llvm::Value *chunk = loop.schedule_chunk_var()
554                            ? lookupValue(loop.schedule_chunk_var())
555                            : llvm::ConstantInt::get(ivType, 1);
556 
557   // Set up the source location value for OpenMP runtime.
558   llvm::DISubprogram *subprogram =
559       builder.GetInsertBlock()->getParent()->getSubprogram();
560   const llvm::DILocation *diLoc =
561       debugTranslation->translateLoc(opInst.getLoc(), subprogram);
562   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
563                                                     llvm::DebugLoc(diLoc));
564 
565   // Generator of the canonical loop body. Produces an SESE region of basic
566   // blocks.
567   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
568   // relying on captured variables.
569   LogicalResult bodyGenStatus = success();
570   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
571     llvm::IRBuilder<>::InsertPointGuard guard(builder);
572 
573     // Make sure further conversions know about the induction variable.
574     mapValue(loop.getRegion().front().getArgument(0), iv);
575 
576     llvm::BasicBlock *entryBlock = ip.getBlock();
577     llvm::BasicBlock *exitBlock =
578         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
579 
580     // Convert the body of the loop.
581     convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
582                         *exitBlock, builder, bodyGenStatus);
583   };
584 
585   // Delegate actual loop construction to the OpenMP IRBuilder.
586   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
587   // i.e. it has a positive step, uses signed integer semantics. Reconsider
588   // this code when WsLoop clearly supports more cases.
589   llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
590   llvm::CanonicalLoopInfo *loopInfo = ompBuilder->createCanonicalLoop(
591       ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
592       /*InclusiveStop=*/loop.inclusive());
593   if (failed(bodyGenStatus))
594     return failure();
595 
596   // TODO: get the alloca insertion point from the parallel operation builder.
597   // If we insert the at the top of the current function, they will be passed as
598   // extra arguments into the function the parallel operation builder outlines.
599   // Put them at the start of the current block for now.
600   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
601       insertBlock, insertBlock->getFirstInsertionPt());
602   loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
603                                                    !loop.nowait(), chunk);
604 
605   // Continue building IR after the loop.
606   builder.restoreIP(loopInfo->getAfterIP());
607   return success();
608 }
609 
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
/// (including OpenMP runtime calls).
///
/// The OpenMPIRBuilder is created and initialized lazily on the first OpenMP
/// operation encountered. Operations with a region are delegated to the
/// dedicated convertOmp* helpers; unknown OpenMP operations produce an
/// "unsupported OpenMP operation" error.
LogicalResult
ModuleTranslation::convertOmpOperation(Operation &opInst,
                                       llvm::IRBuilder<> &builder) {
  if (!ompBuilder) {
    ompBuilder = std::make_unique<llvm::OpenMPIRBuilder>(*llvmModule);
    ompBuilder->initialize();
  }
  return llvm::TypeSwitch<Operation *, LogicalResult>(&opInst)
      .Case([&](omp::BarrierOp) {
        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
        return success();
      })
      .Case([&](omp::TaskwaitOp) {
        ompBuilder->createTaskwait(builder.saveIP());
        return success();
      })
      .Case([&](omp::TaskyieldOp) {
        ompBuilder->createTaskyield(builder.saveIP());
        return success();
      })
      .Case([&](omp::FlushOp) {
        // No support in Openmp runtime function (__kmpc_flush) to accept
        // the argument list.
        // OpenMP standard states the following:
        //  "An implementation may implement a flush with a list by ignoring
        //   the list, and treating it the same as a flush without a list."
        //
        // The argument list is discarded so that, flush with a list is treated
        // same as a flush without a list.
        ompBuilder->createFlush(builder.saveIP());
        return success();
      })
      .Case(
          [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
      .Case([&](omp::MasterOp) { return convertOmpMaster(opInst, builder); })
      .Case([&](omp::WsLoopOp) { return convertOmpWsLoop(opInst, builder); })
      .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) {
        // `yield` and `terminator` can be just omitted. The block structure was
        // created in the function that handles their parent operation.
        assert(op->getNumOperands() == 0 &&
               "unexpected OpenMP terminator with operands");
        return success();
      })
      .Default([&](Operation *inst) {
        return inst->emitError("unsupported OpenMP operation: ")
               << inst->getName();
      });
}
660 
661 static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
662   using llvmFMF = llvm::FastMathFlags;
663   using FuncT = void (llvmFMF::*)(bool);
664   const std::pair<FastmathFlags, FuncT> handlers[] = {
665       // clang-format off
666       {FastmathFlags::nnan,     &llvmFMF::setNoNaNs},
667       {FastmathFlags::ninf,     &llvmFMF::setNoInfs},
668       {FastmathFlags::nsz,      &llvmFMF::setNoSignedZeros},
669       {FastmathFlags::arcp,     &llvmFMF::setAllowReciprocal},
670       {FastmathFlags::contract, &llvmFMF::setAllowContract},
671       {FastmathFlags::afn,      &llvmFMF::setApproxFunc},
672       {FastmathFlags::reassoc,  &llvmFMF::setAllowReassoc},
673       {FastmathFlags::fast,     &llvmFMF::setFast},
674       // clang-format on
675   };
676   llvm::FastMathFlags ret;
677   auto fmf = op.fastmathFlags();
678   for (auto it : handlers)
679     if (bitEnumContains(fmf, it.first))
680       (ret.*(it.second))(true);
681   return ret;
682 }
683 
684 /// Given a single MLIR operation, create the corresponding LLVM IR operation
685 /// using the `builder`.  LLVM IR Builder does not have a generic interface so
686 /// this has to be a long chain of `if`s calling different functions with a
687 /// different number of arguments.
688 LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
689                                                   llvm::IRBuilder<> &builder) {
690   auto extractPosition = [](ArrayAttr attr) {
691     SmallVector<unsigned, 4> position;
692     position.reserve(attr.size());
693     for (Attribute v : attr)
694       position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
695     return position;
696   };
697 
698   llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
699   if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
700     builder.setFastMathFlags(getFastmathFlags(fmf));
701 
702 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
703 
704   // Emit function calls.  If the "callee" attribute is present, this is a
705   // direct function call and we also need to look up the remapped function
706   // itself.  Otherwise, this is an indirect call and the callee is the first
707   // operand, look it up as a normal value.  Return the llvm::Value representing
708   // the function result, which may be of llvm::VoidTy type.
709   auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
710     auto operands = lookupValues(op.getOperands());
711     ArrayRef<llvm::Value *> operandsRef(operands);
712     if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
713       return builder.CreateCall(lookupFunction(attr.getValue()), operandsRef);
714     auto *calleePtrType =
715         cast<llvm::PointerType>(operandsRef.front()->getType());
716     auto *calleeType =
717         cast<llvm::FunctionType>(calleePtrType->getElementType());
718     return builder.CreateCall(calleeType, operandsRef.front(),
719                               operandsRef.drop_front());
720   };
721 
722   // Emit calls.  If the called function has a result, remap the corresponding
723   // value.  Note that LLVM IR dialect CallOp has either 0 or 1 result.
724   if (isa<LLVM::CallOp>(opInst)) {
725     llvm::Value *result = convertCall(opInst);
726     if (opInst.getNumResults() != 0) {
727       mapValue(opInst.getResult(0), result);
728       return success();
729     }
730     // Check that LLVM call returns void for 0-result functions.
731     return success(result->getType()->isVoidTy());
732   }
733 
734   if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
735     // TODO: refactor function type creation which usually occurs in std-LLVM
736     // conversion.
737     SmallVector<Type, 8> operandTypes;
738     operandTypes.reserve(inlineAsmOp.operands().size());
739     for (auto t : inlineAsmOp.operands().getTypes())
740       operandTypes.push_back(t);
741 
742     Type resultType;
743     if (inlineAsmOp.getNumResults() == 0) {
744       resultType = LLVM::LLVMVoidType::get(mlirModule->getContext());
745     } else {
746       assert(inlineAsmOp.getNumResults() == 1);
747       resultType = inlineAsmOp.getResultTypes()[0];
748     }
749     auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
750     llvm::InlineAsm *inlineAsmInst =
751         inlineAsmOp.asm_dialect().hasValue()
752             ? llvm::InlineAsm::get(
753                   static_cast<llvm::FunctionType *>(convertType(ft)),
754                   inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
755                   inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack(),
756                   convertAsmDialectToLLVM(*inlineAsmOp.asm_dialect()))
757             : llvm::InlineAsm::get(
758                   static_cast<llvm::FunctionType *>(convertType(ft)),
759                   inlineAsmOp.asm_string(), inlineAsmOp.constraints(),
760                   inlineAsmOp.has_side_effects(), inlineAsmOp.is_align_stack());
761     llvm::Value *result =
762         builder.CreateCall(inlineAsmInst, lookupValues(inlineAsmOp.operands()));
763     if (opInst.getNumResults() != 0)
764       mapValue(opInst.getResult(0), result);
765     return success();
766   }
767 
768   if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
769     auto operands = lookupValues(opInst.getOperands());
770     ArrayRef<llvm::Value *> operandsRef(operands);
771     if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
772       builder.CreateInvoke(lookupFunction(attr.getValue()),
773                            lookupBlock(invOp.getSuccessor(0)),
774                            lookupBlock(invOp.getSuccessor(1)), operandsRef);
775     } else {
776       auto *calleePtrType =
777           cast<llvm::PointerType>(operandsRef.front()->getType());
778       auto *calleeType =
779           cast<llvm::FunctionType>(calleePtrType->getElementType());
780       builder.CreateInvoke(
781           calleeType, operandsRef.front(), lookupBlock(invOp.getSuccessor(0)),
782           lookupBlock(invOp.getSuccessor(1)), operandsRef.drop_front());
783     }
784     return success();
785   }
786 
787   if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
788     llvm::Type *ty = convertType(lpOp.getType());
789     llvm::LandingPadInst *lpi =
790         builder.CreateLandingPad(ty, lpOp.getNumOperands());
791 
792     // Add clauses
793     for (llvm::Value *operand : lookupValues(lpOp.getOperands())) {
794       // All operands should be constant - checked by verifier
795       if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
796         lpi->addClause(constOperand);
797     }
798     mapValue(lpOp.getResult(), lpi);
799     return success();
800   }
801 
802   // Emit branches.  We need to look up the remapped blocks and ignore the block
803   // arguments that were transformed into PHI nodes.
804   if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
805     llvm::BranchInst *branch =
806         builder.CreateBr(lookupBlock(brOp.getSuccessor()));
807     mapBranch(&opInst, branch);
808     return success();
809   }
810   if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
811     auto weights = condbrOp.branch_weights();
812     llvm::MDNode *branchWeights = nullptr;
813     if (weights) {
814       // Map weight attributes to LLVM metadata.
815       auto trueWeight =
816           weights.getValue().getValue(0).cast<IntegerAttr>().getInt();
817       auto falseWeight =
818           weights.getValue().getValue(1).cast<IntegerAttr>().getInt();
819       branchWeights =
820           llvm::MDBuilder(llvmModule->getContext())
821               .createBranchWeights(static_cast<uint32_t>(trueWeight),
822                                    static_cast<uint32_t>(falseWeight));
823     }
824     llvm::BranchInst *branch = builder.CreateCondBr(
825         lookupValue(condbrOp.getOperand(0)),
826         lookupBlock(condbrOp.getSuccessor(0)),
827         lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
828     mapBranch(&opInst, branch);
829     return success();
830   }
831   if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
832     llvm::MDNode *branchWeights = nullptr;
833     if (auto weights = switchOp.branch_weights()) {
834       llvm::SmallVector<uint32_t> weightValues;
835       weightValues.reserve(weights->size());
836       for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
837         weightValues.push_back(weight.getLimitedValue());
838       branchWeights = llvm::MDBuilder(llvmModule->getContext())
839                           .createBranchWeights(weightValues);
840     }
841 
842     llvm::SwitchInst *switchInst =
843         builder.CreateSwitch(lookupValue(switchOp.value()),
844                              lookupBlock(switchOp.defaultDestination()),
845                              switchOp.caseDestinations().size(), branchWeights);
846 
847     auto *ty =
848         llvm::cast<llvm::IntegerType>(convertType(switchOp.value().getType()));
849     for (auto i :
850          llvm::zip(switchOp.case_values()->cast<DenseIntElementsAttr>(),
851                    switchOp.caseDestinations()))
852       switchInst->addCase(
853           llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
854           lookupBlock(std::get<1>(i)));
855 
856     mapBranch(&opInst, switchInst);
857     return success();
858   }
859 
860   // Emit addressof.  We need to look up the global value referenced by the
861   // operation and store it in the MLIR-to-LLVM value mapping.  This does not
862   // emit any LLVM instruction.
863   if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
864     LLVM::GlobalOp global = addressOfOp.getGlobal();
865     LLVM::LLVMFuncOp function = addressOfOp.getFunction();
866 
867     // The verifier should not have allowed this.
868     assert((global || function) &&
869            "referencing an undefined global or function");
870 
871     mapValue(addressOfOp.getResult(), global
872                                           ? globalsMapping.lookup(global)
873                                           : lookupFunction(function.getName()));
874     return success();
875   }
876 
877   if (ompDialect && opInst.getDialect() == ompDialect)
878     return convertOmpOperation(opInst, builder);
879 
880   return opInst.emitError("unsupported or non-LLVM operation: ")
881          << opInst.getName();
882 }
883 
884 /// Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
885 /// to define values corresponding to the MLIR block arguments.  These nodes
886 /// are not connected to the source basic blocks, which may not exist yet.  Uses
887 /// `builder` to construct the LLVM IR. Expects the LLVM IR basic block to have
888 /// been created for `bb` and included in the block mapping.  Inserts new
889 /// instructions at the end of the block and leaves `builder` in a state
890 /// suitable for further insertion into the end of the block.
891 LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments,
892                                               llvm::IRBuilder<> &builder) {
893   builder.SetInsertPoint(lookupBlock(&bb));
894   auto *subprogram = builder.GetInsertBlock()->getParent()->getSubprogram();
895 
896   // Before traversing operations, make block arguments available through
897   // value remapping and PHI nodes, but do not add incoming edges for the PHI
898   // nodes just yet: those values may be defined by this or following blocks.
899   // This step is omitted if "ignoreArguments" is set.  The arguments of the
900   // first block have been already made available through the remapping of
901   // LLVM function arguments.
902   if (!ignoreArguments) {
903     auto predecessors = bb.getPredecessors();
904     unsigned numPredecessors =
905         std::distance(predecessors.begin(), predecessors.end());
906     for (auto arg : bb.getArguments()) {
907       auto wrappedType = arg.getType();
908       if (!isCompatibleType(wrappedType))
909         return emitError(bb.front().getLoc(),
910                          "block argument does not have an LLVM type");
911       llvm::Type *type = convertType(wrappedType);
912       llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
913       mapValue(arg, phi);
914     }
915   }
916 
917   // Traverse operations.
918   for (auto &op : bb) {
919     // Set the current debug location within the builder.
920     builder.SetCurrentDebugLocation(
921         debugTranslation->translateLoc(op.getLoc(), subprogram));
922 
923     if (failed(convertOperation(op, builder)))
924       return failure();
925   }
926 
927   return success();
928 }
929 
930 /// Create named global variables that correspond to llvm.mlir.global
931 /// definitions.
932 LogicalResult ModuleTranslation::convertGlobals() {
933   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
934     llvm::Type *type = convertType(op.getType());
935     llvm::Constant *cst = llvm::UndefValue::get(type);
936     if (op.getValueOrNull()) {
937       // String attributes are treated separately because they cannot appear as
938       // in-function constants and are thus not supported by getLLVMConstant.
939       if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
940         cst = llvm::ConstantDataArray::getString(
941             llvmModule->getContext(), strAttr.getValue(), /*AddNull=*/false);
942         type = cst->getType();
943       } else if (!(cst = getLLVMConstant(type, op.getValueOrNull(),
944                                          op.getLoc()))) {
945         return failure();
946       }
947     } else if (Block *initializer = op.getInitializerBlock()) {
948       llvm::IRBuilder<> builder(llvmModule->getContext());
949       for (auto &op : initializer->without_terminator()) {
950         if (failed(convertOperation(op, builder)) ||
951             !isa<llvm::Constant>(lookupValue(op.getResult(0))))
952           return emitError(op.getLoc(), "unemittable constant value");
953       }
954       ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
955       cst = cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
956     }
957 
958     auto linkage = convertLinkageToLLVM(op.linkage());
959     bool anyExternalLinkage =
960         ((linkage == llvm::GlobalVariable::ExternalLinkage &&
961           isa<llvm::UndefValue>(cst)) ||
962          linkage == llvm::GlobalVariable::ExternalWeakLinkage);
963     auto addrSpace = op.addr_space();
964     auto *var = new llvm::GlobalVariable(
965         *llvmModule, type, op.constant(), linkage,
966         anyExternalLinkage ? nullptr : cst, op.sym_name(),
967         /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal, addrSpace);
968 
969     globalsMapping.try_emplace(op, var);
970   }
971 
972   return success();
973 }
974 
975 /// Attempts to add an attribute identified by `key`, optionally with the given
976 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
977 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
978 /// otherwise keep it as a string attribute. Performs additional checks for
979 /// attributes known to have or not have a value in order to avoid assertions
980 /// inside LLVM upon construction.
981 static LogicalResult checkedAddLLVMFnAttribute(Location loc,
982                                                llvm::Function *llvmFunc,
983                                                StringRef key,
984                                                StringRef value = StringRef()) {
985   auto kind = llvm::Attribute::getAttrKindFromName(key);
986   if (kind == llvm::Attribute::None) {
987     llvmFunc->addFnAttr(key, value);
988     return success();
989   }
990 
991   if (llvm::Attribute::doesAttrKindHaveArgument(kind)) {
992     if (value.empty())
993       return emitError(loc) << "LLVM attribute '" << key << "' expects a value";
994 
995     int result;
996     if (!value.getAsInteger(/*Radix=*/0, result))
997       llvmFunc->addFnAttr(
998           llvm::Attribute::get(llvmFunc->getContext(), kind, result));
999     else
1000       llvmFunc->addFnAttr(key, value);
1001     return success();
1002   }
1003 
1004   if (!value.empty())
1005     return emitError(loc) << "LLVM attribute '" << key
1006                           << "' does not expect a value, found '" << value
1007                           << "'";
1008 
1009   llvmFunc->addFnAttr(kind);
1010   return success();
1011 }
1012 
1013 /// Attaches the attributes listed in the given array attribute to `llvmFunc`.
1014 /// Reports error to `loc` if any and returns immediately. Expects `attributes`
1015 /// to be an array attribute containing either string attributes, treated as
1016 /// value-less LLVM attributes, or array attributes containing two string
1017 /// attributes, with the first string being the name of the corresponding LLVM
1018 /// attribute and the second string beings its value. Note that even integer
1019 /// attributes are expected to have their values expressed as strings.
1020 static LogicalResult
1021 forwardPassthroughAttributes(Location loc, Optional<ArrayAttr> attributes,
1022                              llvm::Function *llvmFunc) {
1023   if (!attributes)
1024     return success();
1025 
1026   for (Attribute attr : *attributes) {
1027     if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
1028       if (failed(
1029               checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue())))
1030         return failure();
1031       continue;
1032     }
1033 
1034     auto arrayAttr = attr.dyn_cast<ArrayAttr>();
1035     if (!arrayAttr || arrayAttr.size() != 2)
1036       return emitError(loc)
1037              << "expected 'passthrough' to contain string or array attributes";
1038 
1039     auto keyAttr = arrayAttr[0].dyn_cast<StringAttr>();
1040     auto valueAttr = arrayAttr[1].dyn_cast<StringAttr>();
1041     if (!keyAttr || !valueAttr)
1042       return emitError(loc)
1043              << "expected arrays within 'passthrough' to contain two strings";
1044 
1045     if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(),
1046                                          valueAttr.getValue())))
1047       return failure();
1048   }
1049   return success();
1050 }
1051 
1052 LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
1053   // Clear the block, branch value mappings, they are only relevant within one
1054   // function.
1055   blockMapping.clear();
1056   valueMapping.clear();
1057   branchMapping.clear();
1058   llvm::Function *llvmFunc = lookupFunction(func.getName());
1059 
1060   // Translate the debug information for this function.
1061   debugTranslation->translate(func, *llvmFunc);
1062 
1063   // Add function arguments to the value remapping table.
1064   // If there was noalias info then we decorate each argument accordingly.
1065   unsigned int argIdx = 0;
1066   for (auto kvp : llvm::zip(func.getArguments(), llvmFunc->args())) {
1067     llvm::Argument &llvmArg = std::get<1>(kvp);
1068     BlockArgument mlirArg = std::get<0>(kvp);
1069 
1070     if (auto attr = func.getArgAttrOfType<BoolAttr>(
1071             argIdx, LLVMDialect::getNoAliasAttrName())) {
1072       // NB: Attribute already verified to be boolean, so check if we can indeed
1073       // attach the attribute to this argument, based on its type.
1074       auto argTy = mlirArg.getType();
1075       if (!argTy.isa<LLVM::LLVMPointerType>())
1076         return func.emitError(
1077             "llvm.noalias attribute attached to LLVM non-pointer argument");
1078       if (attr.getValue())
1079         llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
1080     }
1081 
1082     if (auto attr = func.getArgAttrOfType<IntegerAttr>(
1083             argIdx, LLVMDialect::getAlignAttrName())) {
1084       // NB: Attribute already verified to be int, so check if we can indeed
1085       // attach the attribute to this argument, based on its type.
1086       auto argTy = mlirArg.getType();
1087       if (!argTy.isa<LLVM::LLVMPointerType>())
1088         return func.emitError(
1089             "llvm.align attribute attached to LLVM non-pointer argument");
1090       llvmArg.addAttrs(
1091           llvm::AttrBuilder().addAlignmentAttr(llvm::Align(attr.getInt())));
1092     }
1093 
1094     if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.sret")) {
1095       auto argTy = mlirArg.getType();
1096       if (!argTy.isa<LLVM::LLVMPointerType>())
1097         return func.emitError(
1098             "llvm.sret attribute attached to LLVM non-pointer argument");
1099       llvmArg.addAttrs(llvm::AttrBuilder().addStructRetAttr(
1100           llvmArg.getType()->getPointerElementType()));
1101     }
1102 
1103     if (auto attr = func.getArgAttrOfType<UnitAttr>(argIdx, "llvm.byval")) {
1104       auto argTy = mlirArg.getType();
1105       if (!argTy.isa<LLVM::LLVMPointerType>())
1106         return func.emitError(
1107             "llvm.byval attribute attached to LLVM non-pointer argument");
1108       llvmArg.addAttrs(llvm::AttrBuilder().addByValAttr(
1109           llvmArg.getType()->getPointerElementType()));
1110     }
1111 
1112     mapValue(mlirArg, &llvmArg);
1113     argIdx++;
1114   }
1115 
1116   // Check the personality and set it.
1117   if (func.personality().hasValue()) {
1118     llvm::Type *ty = llvm::Type::getInt8PtrTy(llvmFunc->getContext());
1119     if (llvm::Constant *pfunc =
1120             getLLVMConstant(ty, func.personalityAttr(), func.getLoc()))
1121       llvmFunc->setPersonalityFn(pfunc);
1122   }
1123 
1124   // First, create all blocks so we can jump to them.
1125   llvm::LLVMContext &llvmContext = llvmFunc->getContext();
1126   for (auto &bb : func) {
1127     auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
1128     llvmBB->insertInto(llvmFunc);
1129     mapBlock(&bb, llvmBB);
1130   }
1131 
1132   // Then, convert blocks one by one in topological order to ensure defs are
1133   // converted before uses.
1134   auto blocks = topologicalSort(func);
1135   for (Block *bb : blocks) {
1136     llvm::IRBuilder<> builder(llvmContext);
1137     if (failed(convertBlock(*bb, bb->isEntryBlock(), builder)))
1138       return failure();
1139   }
1140 
1141   // Finally, after all blocks have been traversed and values mapped, connect
1142   // the PHI nodes to the results of preceding blocks.
1143   connectPHINodes(func, *this);
1144   return success();
1145 }
1146 
1147 LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
1148   for (Operation &o : getModuleBody(m).getOperations())
1149     if (!isa<LLVM::LLVMFuncOp, LLVM::GlobalOp>(&o) &&
1150         !o.hasTrait<OpTrait::IsTerminator>())
1151       return o.emitOpError("unsupported module-level operation");
1152   return success();
1153 }
1154 
1155 LogicalResult ModuleTranslation::convertFunctionSignatures() {
1156   // Declare all functions first because there may be function calls that form a
1157   // call graph with cycles, or global initializers that reference functions.
1158   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1159     llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction(
1160         function.getName(),
1161         cast<llvm::FunctionType>(convertType(function.getType())));
1162     llvm::Function *llvmFunc = cast<llvm::Function>(llvmFuncCst.getCallee());
1163     llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage()));
1164     mapFunction(function.getName(), llvmFunc);
1165 
1166     // Forward the pass-through attributes to LLVM.
1167     if (failed(forwardPassthroughAttributes(function.getLoc(),
1168                                             function.passthrough(), llvmFunc)))
1169       return failure();
1170   }
1171 
1172   return success();
1173 }
1174 
1175 LogicalResult ModuleTranslation::convertFunctions() {
1176   // Convert functions.
1177   for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
1178     // Ignore external functions.
1179     if (function.isExternal())
1180       continue;
1181 
1182     if (failed(convertOneFunction(function)))
1183       return failure();
1184   }
1185 
1186   return success();
1187 }
1188 
1189 llvm::Type *ModuleTranslation::convertType(Type type) {
1190   return typeTranslator.translateType(type);
1191 }
1192 
1193 /// A helper to look up remapped operands in the value remapping table.`
1194 SmallVector<llvm::Value *, 8>
1195 ModuleTranslation::lookupValues(ValueRange values) {
1196   SmallVector<llvm::Value *, 8> remapped;
1197   remapped.reserve(values.size());
1198   for (Value v : values)
1199     remapped.push_back(lookupValue(v));
1200   return remapped;
1201 }
1202 
1203 std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(
1204     Operation *m, llvm::LLVMContext &llvmContext, StringRef name) {
1205   m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
1206   auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
1207   if (auto dataLayoutAttr =
1208           m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()))
1209     llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue());
1210   if (auto targetTripleAttr =
1211           m->getAttr(LLVM::LLVMDialect::getTargetTripleAttrName()))
1212     llvmModule->setTargetTriple(targetTripleAttr.cast<StringAttr>().getValue());
1213 
1214   // Inject declarations for `malloc` and `free` functions that can be used in
1215   // memref allocation/deallocation coming from standard ops lowering.
1216   llvm::IRBuilder<> builder(llvmContext);
1217   llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
1218                                   builder.getInt64Ty());
1219   llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
1220                                   builder.getInt8PtrTy());
1221 
1222   return llvmModule;
1223 }
1224