1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Support/LLVM.h"
18 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
19 
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
23 #include "llvm/IR/DebugInfoMetadata.h"
24 #include "llvm/IR/IRBuilder.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 /// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
30 /// insertion points for allocas.
31 class OpenMPAllocaStackFrame
32     : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
33 public:
34   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
35 
36   explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
37       : allocaInsertPoint(allocaIP) {}
38   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
39 };
40 
41 /// ModuleTranslation stack frame containing the partial mapping between MLIR
42 /// values and their LLVM IR equivalents.
43 class OpenMPVarMappingStackFrame
44     : public LLVM::ModuleTranslation::StackFrameBase<
45           OpenMPVarMappingStackFrame> {
46 public:
47   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPVarMappingStackFrame)
48 
49   explicit OpenMPVarMappingStackFrame(
50       const DenseMap<Value, llvm::Value *> &mapping)
51       : mapping(mapping) {}
52 
53   DenseMap<Value, llvm::Value *> mapping;
54 };
55 } // namespace
56 
57 /// Find the insertion point for allocas given the current insertion point for
58 /// normal operations in the builder.
59 static llvm::OpenMPIRBuilder::InsertPointTy
60 findAllocaInsertPoint(llvm::IRBuilderBase &builder,
61                       const LLVM::ModuleTranslation &moduleTranslation) {
62   // If there is an alloca insertion point on stack, i.e. we are in a nested
63   // operation and a specific point was provided by some surrounding operation,
64   // use it.
65   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
66   WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
67       [&](const OpenMPAllocaStackFrame &frame) {
68         allocaInsertPoint = frame.allocaInsertPoint;
69         return WalkResult::interrupt();
70       });
71   if (walkResult.wasInterrupted())
72     return allocaInsertPoint;
73 
74   // Otherwise, insert to the entry block of the surrounding function.
75   llvm::BasicBlock &funcEntryBlock =
76       builder.GetInsertBlock()->getParent()->getEntryBlock();
77   return llvm::OpenMPIRBuilder::InsertPointTy(
78       &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
79 }
80 
81 /// Converts the given region that appears within an OpenMP dialect operation to
82 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
83 /// region, and a branch from any block with an successor-less OpenMP terminator
84 /// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
85 /// of the continuation block if provided.
86 static void convertOmpOpRegions(
87     Region &region, StringRef blockName, llvm::BasicBlock &sourceBlock,
88     llvm::BasicBlock &continuationBlock, llvm::IRBuilderBase &builder,
89     LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus,
90     SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
91   llvm::LLVMContext &llvmContext = builder.getContext();
92   for (Block &bb : region) {
93     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
94         llvmContext, blockName, builder.GetInsertBlock()->getParent(),
95         builder.GetInsertBlock()->getNextNode());
96     moduleTranslation.mapBlock(&bb, llvmBB);
97   }
98 
99   llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
100 
101   // Terminators (namely YieldOp) may be forwarding values to the region that
102   // need to be available in the continuation block. Collect the types of these
103   // operands in preparation of creating PHI nodes.
104   SmallVector<llvm::Type *> continuationBlockPHITypes;
105   bool operandsProcessed = false;
106   unsigned numYields = 0;
107   for (Block &bb : region.getBlocks()) {
108     if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
109       if (!operandsProcessed) {
110         for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
111           continuationBlockPHITypes.push_back(
112               moduleTranslation.convertType(yield->getOperand(i).getType()));
113         }
114         operandsProcessed = true;
115       } else {
116         assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
117                "mismatching number of values yielded from the region");
118         for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
119           llvm::Type *operandType =
120               moduleTranslation.convertType(yield->getOperand(i).getType());
121           (void)operandType;
122           assert(continuationBlockPHITypes[i] == operandType &&
123                  "values of mismatching types yielded from the region");
124         }
125       }
126       numYields++;
127     }
128   }
129 
130   // Insert PHI nodes in the continuation block for any values forwarded by the
131   // terminators in this region.
132   if (!continuationBlockPHITypes.empty())
133     assert(
134         continuationBlockPHIs &&
135         "expected continuation block PHIs if converted regions yield values");
136   if (continuationBlockPHIs) {
137     llvm::IRBuilderBase::InsertPointGuard guard(builder);
138     continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
139     builder.SetInsertPoint(&continuationBlock, continuationBlock.begin());
140     for (llvm::Type *ty : continuationBlockPHITypes)
141       continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
142   }
143 
144   // Convert blocks one by one in topological order to ensure
145   // defs are converted before uses.
146   SetVector<Block *> blocks =
147       LLVM::detail::getTopologicallySortedBlocks(region);
148   for (Block *bb : blocks) {
149     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
150     // Retarget the branch of the entry block to the entry block of the
151     // converted region (regions are single-entry).
152     if (bb->isEntryBlock()) {
153       assert(sourceTerminator->getNumSuccessors() == 1 &&
154              "provided entry block has multiple successors");
155       assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
156              "ContinuationBlock is not the successor of the entry block");
157       sourceTerminator->setSuccessor(0, llvmBB);
158     }
159 
160     llvm::IRBuilderBase::InsertPointGuard guard(builder);
161     if (failed(
162             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
163       bodyGenStatus = failure();
164       return;
165     }
166 
167     // Special handling for `omp.yield` and `omp.terminator` (we may have more
168     // than one): they return the control to the parent OpenMP dialect operation
169     // so replace them with the branch to the continuation block. We handle this
170     // here to avoid relying inter-function communication through the
171     // ModuleTranslation class to set up the correct insertion point. This is
172     // also consistent with MLIR's idiom of handling special region terminators
173     // in the same code that handles the region-owning operation.
174     Operation *terminator = bb->getTerminator();
175     if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
176       builder.CreateBr(&continuationBlock);
177 
178       for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
179         (*continuationBlockPHIs)[i]->addIncoming(
180             moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
181     }
182   }
183   // After all blocks have been traversed and values mapped, connect the PHI
184   // nodes to the results of preceding blocks.
185   LLVM::detail::connectPHINodes(region, moduleTranslation);
186 
187   // Remove the blocks and values defined in this region from the mapping since
188   // they are not visible outside of this region. This allows the same region to
189   // be converted several times, that is cloned, without clashes, and slightly
190   // speeds up the lookups.
191   moduleTranslation.forgetMapping(region);
192 }
193 
194 /// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
195 static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
196   switch (kind) {
197   case omp::ClauseProcBindKind::Close:
198     return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
199   case omp::ClauseProcBindKind::Master:
200     return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
201   case omp::ClauseProcBindKind::Primary:
202     return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
203   case omp::ClauseProcBindKind::Spread:
204     return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
205   }
206   llvm_unreachable("Unknown ClauseProcBindKind kind");
207 }
208 
209 /// Converts the OpenMP parallel operation to LLVM IR.
210 static LogicalResult
211 convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
212                    LLVM::ModuleTranslation &moduleTranslation) {
213   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
214   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
215   // relying on captured variables.
216   LogicalResult bodyGenStatus = success();
217 
218   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
219                        llvm::BasicBlock &continuationBlock) {
220     // Save the alloca insertion point on ModuleTranslation stack for use in
221     // nested regions.
222     LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
223         moduleTranslation, allocaIP);
224 
225     // ParallelOp has only one region associated with it.
226     convertOmpOpRegions(opInst.getRegion(), "omp.par.region",
227                         *codeGenIP.getBlock(), continuationBlock, builder,
228                         moduleTranslation, bodyGenStatus);
229   };
230 
231   // TODO: Perform appropriate actions according to the data-sharing
232   // attribute (shared, private, firstprivate, ...) of variables.
233   // Currently defaults to shared.
234   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
235                     llvm::Value &, llvm::Value &vPtr,
236                     llvm::Value *&replacementValue) -> InsertPointTy {
237     replacementValue = &vPtr;
238 
239     return codeGenIP;
240   };
241 
242   // TODO: Perform finalization actions for variables. This has to be
243   // called for variables which have destructors/finalizers.
244   auto finiCB = [&](InsertPointTy codeGenIP) {};
245 
246   llvm::Value *ifCond = nullptr;
247   if (auto ifExprVar = opInst.if_expr_var())
248     ifCond = moduleTranslation.lookupValue(ifExprVar);
249   llvm::Value *numThreads = nullptr;
250   if (auto numThreadsVar = opInst.num_threads_var())
251     numThreads = moduleTranslation.lookupValue(numThreadsVar);
252   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
253   if (auto bind = opInst.proc_bind_val())
254     pbKind = getProcBindKind(*bind);
255   // TODO: Is the Parallel construct cancellable?
256   bool isCancellable = false;
257 
258   // Ensure that the BasicBlock for the the parallel region is sparate from the
259   // function entry which we may need to insert allocas.
260   if (builder.GetInsertBlock() ==
261       &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
262     assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
263            "Assuming end of basic block");
264     llvm::BasicBlock *entryBB =
265         llvm::BasicBlock::Create(builder.getContext(), "parallel.entry",
266                                  builder.GetInsertBlock()->getParent(),
267                                  builder.GetInsertBlock()->getNextNode());
268     builder.CreateBr(entryBB);
269     builder.SetInsertPoint(entryBB);
270   }
271   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
272   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
273       ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB,
274       privCB, finiCB, ifCond, numThreads, pbKind, isCancellable));
275 
276   return bodyGenStatus;
277 }
278 
279 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
280 static LogicalResult
281 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
282                  LLVM::ModuleTranslation &moduleTranslation) {
283   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
284   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
285   // relying on captured variables.
286   LogicalResult bodyGenStatus = success();
287 
288   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
289                        llvm::BasicBlock &continuationBlock) {
290     // MasterOp has only one region associated with it.
291     auto &region = cast<omp::MasterOp>(opInst).getRegion();
292     convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
293                         continuationBlock, builder, moduleTranslation,
294                         bodyGenStatus);
295   };
296 
297   // TODO: Perform finalization actions for variables. This has to be
298   // called for variables which have destructors/finalizers.
299   auto finiCB = [&](InsertPointTy codeGenIP) {};
300 
301   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
302   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
303       ompLoc, bodyGenCB, finiCB));
304   return success();
305 }
306 
307 /// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
308 static LogicalResult
309 convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
310                    LLVM::ModuleTranslation &moduleTranslation) {
311   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
312   auto criticalOp = cast<omp::CriticalOp>(opInst);
313   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
314   // relying on captured variables.
315   LogicalResult bodyGenStatus = success();
316 
317   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
318                        llvm::BasicBlock &continuationBlock) {
319     // CriticalOp has only one region associated with it.
320     auto &region = cast<omp::CriticalOp>(opInst).getRegion();
321     convertOmpOpRegions(region, "omp.critical.region", *codeGenIP.getBlock(),
322                         continuationBlock, builder, moduleTranslation,
323                         bodyGenStatus);
324   };
325 
326   // TODO: Perform finalization actions for variables. This has to be
327   // called for variables which have destructors/finalizers.
328   auto finiCB = [&](InsertPointTy codeGenIP) {};
329 
330   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
331   llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
332   llvm::Constant *hint = nullptr;
333 
334   // If it has a name, it probably has a hint too.
335   if (criticalOp.nameAttr()) {
336     // The verifiers in OpenMP Dialect guarentee that all the pointers are
337     // non-null
338     auto symbolRef = criticalOp.nameAttr().cast<SymbolRefAttr>();
339     auto criticalDeclareOp =
340         SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
341                                                                      symbolRef);
342     hint =
343         llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
344                                static_cast<int>(criticalDeclareOp.hint_val()));
345   }
346   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
347       ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
348   return success();
349 }
350 
351 /// Returns a reduction declaration that corresponds to the given reduction
352 /// operation in the given container. Currently only supports reductions inside
353 /// WsLoopOp but can be easily extended.
354 static omp::ReductionDeclareOp findReductionDecl(omp::WsLoopOp container,
355                                                  omp::ReductionOp reduction) {
356   SymbolRefAttr reductionSymbol;
357   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) {
358     if (container.reduction_vars()[i] != reduction.accumulator())
359       continue;
360     reductionSymbol = (*container.reductions())[i].cast<SymbolRefAttr>();
361     break;
362   }
363   assert(reductionSymbol &&
364          "reduction operation must be associated with a declaration");
365 
366   return SymbolTable::lookupNearestSymbolFrom<omp::ReductionDeclareOp>(
367       container, reductionSymbol);
368 }
369 
370 /// Populates `reductions` with reduction declarations used in the given loop.
371 static void
372 collectReductionDecls(omp::WsLoopOp loop,
373                       SmallVectorImpl<omp::ReductionDeclareOp> &reductions) {
374   Optional<ArrayAttr> attr = loop.reductions();
375   if (!attr)
376     return;
377 
378   reductions.reserve(reductions.size() + loop.getNumReductionVars());
379   for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
380     reductions.push_back(
381         SymbolTable::lookupNearestSymbolFrom<omp::ReductionDeclareOp>(
382             loop, symbolRef));
383   }
384 }
385 
386 /// Translates the blocks contained in the given region and appends them to at
387 /// the current insertion point of `builder`. The operations of the entry block
388 /// are appended to the current insertion block, which is not expected to have a
389 /// terminator. If set, `continuationBlockArgs` is populated with translated
390 /// values that correspond to the values omp.yield'ed from the region.
391 static LogicalResult inlineConvertOmpRegions(
392     Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
393     LLVM::ModuleTranslation &moduleTranslation,
394     SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
395   if (region.empty())
396     return success();
397 
398   // Special case for single-block regions that don't create additional blocks:
399   // insert operations without creating additional blocks.
400   if (llvm::hasSingleElement(region)) {
401     moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
402     if (failed(moduleTranslation.convertBlock(
403             region.front(), /*ignoreArguments=*/true, builder)))
404       return failure();
405 
406     // The continuation arguments are simply the translated terminator operands.
407     if (continuationBlockArgs)
408       llvm::append_range(
409           *continuationBlockArgs,
410           moduleTranslation.lookupValues(region.front().back().getOperands()));
411 
412     // Drop the mapping that is no longer necessary so that the same region can
413     // be processed multiple times.
414     moduleTranslation.forgetMapping(region);
415     return success();
416   }
417 
418   // Create the continuation block manually instead of calling splitBlock
419   // because the current insertion block may not have a terminator.
420   llvm::BasicBlock *continuationBlock =
421       llvm::BasicBlock::Create(builder.getContext(), blockName + ".cont",
422                                builder.GetInsertBlock()->getParent(),
423                                builder.GetInsertBlock()->getNextNode());
424   builder.CreateBr(continuationBlock);
425 
426   LogicalResult bodyGenStatus = success();
427   SmallVector<llvm::PHINode *> phis;
428   convertOmpOpRegions(region, blockName, *builder.GetInsertBlock(),
429                       *continuationBlock, builder, moduleTranslation,
430                       bodyGenStatus, &phis);
431   if (failed(bodyGenStatus))
432     return failure();
433   if (continuationBlockArgs)
434     llvm::append_range(*continuationBlockArgs, phis);
435   builder.SetInsertPoint(continuationBlock,
436                          continuationBlock->getFirstInsertionPt());
437   return success();
438 }
439 
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
///
/// Non-atomic generator. As used by the generators created in this file, the
/// arguments are, in order: the insertion point at which to emit the
/// reduction code, the left-hand side value, the right-hand side value, and
/// an output reference that receives the result of the reduction. The
/// returned insertion point is where code generation continues; a
/// default-constructed insertion point is returned on failure.
using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
    llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
    llvm::Value *&)>;
/// Atomic generator. As used by the generators created in this file, the
/// arguments are, in order: the insertion point at which to emit the
/// reduction code, a type (ignored by the generator created in this file),
/// the left-hand side value and the right-hand side value. The return value
/// follows the same convention as for the non-atomic generator.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
} // namespace
451 
452 /// Create an OpenMPIRBuilder-compatible reduction generator for the given
453 /// reduction declaration. The generator uses `builder` but ignores its
454 /// insertion point.
455 static OwningReductionGen
456 makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder,
457                  LLVM::ModuleTranslation &moduleTranslation) {
458   // The lambda is mutable because we need access to non-const methods of decl
459   // (which aren't actually mutating it), and we must capture decl by-value to
460   // avoid the dangling reference after the parent function returns.
461   OwningReductionGen gen =
462       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
463                 llvm::Value *lhs, llvm::Value *rhs,
464                 llvm::Value *&result) mutable {
465         Region &reductionRegion = decl.reductionRegion();
466         moduleTranslation.mapValue(reductionRegion.front().getArgument(0), lhs);
467         moduleTranslation.mapValue(reductionRegion.front().getArgument(1), rhs);
468         builder.restoreIP(insertPoint);
469         SmallVector<llvm::Value *> phis;
470         if (failed(inlineConvertOmpRegions(reductionRegion,
471                                            "omp.reduction.nonatomic.body",
472                                            builder, moduleTranslation, &phis)))
473           return llvm::OpenMPIRBuilder::InsertPointTy();
474         assert(phis.size() == 1);
475         result = phis[0];
476         return builder.saveIP();
477       };
478   return gen;
479 }
480 
481 /// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
482 /// given reduction declaration. The generator uses `builder` but ignores its
483 /// insertion point. Returns null if there is no atomic region available in the
484 /// reduction declaration.
485 static OwningAtomicReductionGen
486 makeAtomicReductionGen(omp::ReductionDeclareOp decl,
487                        llvm::IRBuilderBase &builder,
488                        LLVM::ModuleTranslation &moduleTranslation) {
489   if (decl.atomicReductionRegion().empty())
490     return OwningAtomicReductionGen();
491 
492   // The lambda is mutable because we need access to non-const methods of decl
493   // (which aren't actually mutating it), and we must capture decl by-value to
494   // avoid the dangling reference after the parent function returns.
495   OwningAtomicReductionGen atomicGen =
496       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
497                 llvm::Value *lhs, llvm::Value *rhs) mutable {
498         Region &atomicRegion = decl.atomicReductionRegion();
499         moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs);
500         moduleTranslation.mapValue(atomicRegion.front().getArgument(1), rhs);
501         builder.restoreIP(insertPoint);
502         SmallVector<llvm::Value *> phis;
503         if (failed(inlineConvertOmpRegions(atomicRegion,
504                                            "omp.reduction.atomic.body", builder,
505                                            moduleTranslation, &phis)))
506           return llvm::OpenMPIRBuilder::InsertPointTy();
507         assert(phis.empty());
508         return builder.saveIP();
509       };
510   return atomicGen;
511 }
512 
513 /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
514 static LogicalResult
515 convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
516                   LLVM::ModuleTranslation &moduleTranslation) {
517   auto orderedOp = cast<omp::OrderedOp>(opInst);
518 
519   omp::ClauseDepend dependType = *orderedOp.depend_type_val();
520   bool isDependSource = dependType == omp::ClauseDepend::dependsource;
521   unsigned numLoops = orderedOp.num_loops_val().getValue();
522   SmallVector<llvm::Value *> vecValues =
523       moduleTranslation.lookupValues(orderedOp.depend_vec_vars());
524 
525   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
526   size_t indexVecValues = 0;
527   while (indexVecValues < vecValues.size()) {
528     SmallVector<llvm::Value *> storeValues;
529     storeValues.reserve(numLoops);
530     for (unsigned i = 0; i < numLoops; i++) {
531       storeValues.push_back(vecValues[indexVecValues]);
532       indexVecValues++;
533     }
534     builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
535         ompLoc, findAllocaInsertPoint(builder, moduleTranslation), numLoops,
536         storeValues, ".cnt.addr", isDependSource));
537   }
538   return success();
539 }
540 
541 /// Converts an OpenMP 'ordered_region' operation into LLVM IR using
542 /// OpenMPIRBuilder.
543 static LogicalResult
544 convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
545                         LLVM::ModuleTranslation &moduleTranslation) {
546   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
547   auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
548 
549   // TODO: The code generation for ordered simd directive is not supported yet.
550   if (orderedRegionOp.simd())
551     return failure();
552 
553   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
554   // relying on captured variables.
555   LogicalResult bodyGenStatus = success();
556 
557   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
558                        llvm::BasicBlock &continuationBlock) {
559     // OrderedOp has only one region associated with it.
560     auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
561     convertOmpOpRegions(region, "omp.ordered.region", *codeGenIP.getBlock(),
562                         continuationBlock, builder, moduleTranslation,
563                         bodyGenStatus);
564   };
565 
566   // TODO: Perform finalization actions for variables. This has to be
567   // called for variables which have destructors/finalizers.
568   auto finiCB = [&](InsertPointTy codeGenIP) {};
569 
570   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
571   builder.restoreIP(
572       moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
573           ompLoc, bodyGenCB, finiCB, !orderedRegionOp.simd()));
574   return bodyGenStatus;
575 }
576 
577 static LogicalResult
578 convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
579                    LLVM::ModuleTranslation &moduleTranslation) {
580   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
581   using StorableBodyGenCallbackTy =
582       llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
583 
584   auto sectionsOp = cast<omp::SectionsOp>(opInst);
585 
586   // TODO: Support the following clauses: private, firstprivate, lastprivate,
587   // reduction, allocate
588   if (!sectionsOp.reduction_vars().empty() || sectionsOp.reductions() ||
589       !sectionsOp.allocate_vars().empty() ||
590       !sectionsOp.allocators_vars().empty())
591     return emitError(sectionsOp.getLoc())
592            << "reduction and allocate clauses are not supported for sections "
593               "construct";
594 
595   LogicalResult bodyGenStatus = success();
596   SmallVector<StorableBodyGenCallbackTy> sectionCBs;
597 
598   for (Operation &op : *sectionsOp.region().begin()) {
599     auto sectionOp = dyn_cast<omp::SectionOp>(op);
600     if (!sectionOp) // omp.terminator
601       continue;
602 
603     Region &region = sectionOp.region();
604     auto sectionCB = [&region, &builder, &moduleTranslation, &bodyGenStatus](
605                          InsertPointTy allocaIP, InsertPointTy codeGenIP,
606                          llvm::BasicBlock &finiBB) {
607       builder.restoreIP(codeGenIP);
608       builder.CreateBr(&finiBB);
609       convertOmpOpRegions(region, "omp.section.region", *codeGenIP.getBlock(),
610                           finiBB, builder, moduleTranslation, bodyGenStatus);
611     };
612     sectionCBs.push_back(sectionCB);
613   }
614 
615   // No sections within omp.sections operation - skip generation. This situation
616   // is only possible if there is only a terminator operation inside the
617   // sections operation
618   if (sectionCBs.empty())
619     return success();
620 
621   assert(isa<omp::SectionOp>(*sectionsOp.region().op_begin()));
622 
623   // TODO: Perform appropriate actions according to the data-sharing
624   // attribute (shared, private, firstprivate, ...) of variables.
625   // Currently defaults to shared.
626   auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
627                     llvm::Value &vPtr,
628                     llvm::Value *&replacementValue) -> InsertPointTy {
629     replacementValue = &vPtr;
630     return codeGenIP;
631   };
632 
633   // TODO: Perform finalization actions for variables. This has to be
634   // called for variables which have destructors/finalizers.
635   auto finiCB = [&](InsertPointTy codeGenIP) {};
636 
637   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
638   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSections(
639       ompLoc, findAllocaInsertPoint(builder, moduleTranslation), sectionCBs,
640       privCB, finiCB, false, sectionsOp.nowait()));
641   return bodyGenStatus;
642 }
643 
644 /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
645 static LogicalResult
646 convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
647                  LLVM::ModuleTranslation &moduleTranslation) {
648   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
649   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
650   LogicalResult bodyGenStatus = success();
651   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
652                     llvm::BasicBlock &continuationBB) {
653     convertOmpOpRegions(singleOp.region(), "omp.single.region",
654                         *codegenIP.getBlock(), continuationBB, builder,
655                         moduleTranslation, bodyGenStatus);
656   };
657   auto finiCB = [&](InsertPointTy codeGenIP) {};
658   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
659       ompLoc, bodyCB, finiCB, singleOp.nowait(), /*DidIt=*/nullptr));
660   return bodyGenStatus;
661 }
662 
663 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
664 static LogicalResult
665 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
666                  LLVM::ModuleTranslation &moduleTranslation) {
667   auto loop = cast<omp::WsLoopOp>(opInst);
668   // TODO: this should be in the op verifier instead.
669   if (loop.lowerBound().empty())
670     return failure();
671 
672   // Static is the default.
673   auto schedule =
674       loop.schedule_val().getValueOr(omp::ClauseScheduleKind::Static);
675 
676   // Find the loop configuration.
677   llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
678   llvm::Type *ivType = step->getType();
679   llvm::Value *chunk = nullptr;
680   if (loop.schedule_chunk_var()) {
681     llvm::Value *chunkVar =
682         moduleTranslation.lookupValue(loop.schedule_chunk_var());
683     llvm::Type *chunkVarType = chunkVar->getType();
684     assert(chunkVarType->isIntegerTy() &&
685            "chunk size must be one integer expression");
686     if (chunkVarType->getIntegerBitWidth() < ivType->getIntegerBitWidth())
687       chunk = builder.CreateSExt(chunkVar, ivType);
688     else if (chunkVarType->getIntegerBitWidth() > ivType->getIntegerBitWidth())
689       chunk = builder.CreateTrunc(chunkVar, ivType);
690     else
691       chunk = chunkVar;
692   }
693 
694   SmallVector<omp::ReductionDeclareOp> reductionDecls;
695   collectReductionDecls(loop, reductionDecls);
696   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
697       findAllocaInsertPoint(builder, moduleTranslation);
698 
699   // Allocate space for privatized reduction variables.
700   SmallVector<llvm::Value *> privateReductionVariables;
701   DenseMap<Value, llvm::Value *> reductionVariableMap;
702   unsigned numReductions = loop.getNumReductionVars();
703   privateReductionVariables.reserve(numReductions);
704   if (numReductions != 0) {
705     llvm::IRBuilderBase::InsertPointGuard guard(builder);
706     builder.restoreIP(allocaIP);
707     for (unsigned i = 0; i < numReductions; ++i) {
708       auto reductionType =
709           loop.reduction_vars()[i].getType().cast<LLVM::LLVMPointerType>();
710       llvm::Value *var = builder.CreateAlloca(
711           moduleTranslation.convertType(reductionType.getElementType()));
712       privateReductionVariables.push_back(var);
713       reductionVariableMap.try_emplace(loop.reduction_vars()[i], var);
714     }
715   }
716 
717   // Store the mapping between reduction variables and their private copies on
718   // ModuleTranslation stack. It can be then recovered when translating
719   // omp.reduce operations in a separate call.
720   LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
721       moduleTranslation, reductionVariableMap);
722 
723   // Before the loop, store the initial values of reductions into reduction
724   // variables. Although this could be done after allocas, we don't want to mess
725   // up with the alloca insertion point.
726   for (unsigned i = 0; i < numReductions; ++i) {
727     SmallVector<llvm::Value *> phis;
728     if (failed(inlineConvertOmpRegions(reductionDecls[i].initializerRegion(),
729                                        "omp.reduction.neutral", builder,
730                                        moduleTranslation, &phis)))
731       return failure();
732     assert(phis.size() == 1 && "expected one value to be yielded from the "
733                                "reduction neutral element declaration region");
734     builder.CreateStore(phis[0], privateReductionVariables[i]);
735   }
736 
737   // Set up the source location value for OpenMP runtime.
738   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
739 
740   // Generator of the canonical loop body.
741   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
742   // relying on captured variables.
743   SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
744   SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
745   LogicalResult bodyGenStatus = success();
746   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
747     // Make sure further conversions know about the induction variable.
748     moduleTranslation.mapValue(
749         loop.getRegion().front().getArgument(loopInfos.size()), iv);
750 
751     // Capture the body insertion point for use in nested loops. BodyIP of the
752     // CanonicalLoopInfo always points to the beginning of the entry block of
753     // the body.
754     bodyInsertPoints.push_back(ip);
755 
756     if (loopInfos.size() != loop.getNumLoops() - 1)
757       return;
758 
759     // Convert the body of the loop.
760     llvm::BasicBlock *entryBlock = ip.getBlock();
761     llvm::BasicBlock *exitBlock =
762         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
763     convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
764                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
765   };
766 
767   // Delegate actual loop construction to the OpenMP IRBuilder.
768   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
769   // i.e. it has a positive step, uses signed integer semantics. Reconsider
770   // this code when WsLoop clearly supports more cases.
771   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
772   for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
773     llvm::Value *lowerBound =
774         moduleTranslation.lookupValue(loop.lowerBound()[i]);
775     llvm::Value *upperBound =
776         moduleTranslation.lookupValue(loop.upperBound()[i]);
777     llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
778 
779     // Make sure loop trip count are emitted in the preheader of the outermost
780     // loop at the latest so that they are all available for the new collapsed
781     // loop will be created below.
782     llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
783     llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
784     if (i != 0) {
785       loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
786       computeIP = loopInfos.front()->getPreheaderIP();
787     }
788     loopInfos.push_back(ompBuilder->createCanonicalLoop(
789         loc, bodyGen, lowerBound, upperBound, step,
790         /*IsSigned=*/true, loop.inclusive(), computeIP));
791 
792     if (failed(bodyGenStatus))
793       return failure();
794   }
795 
796   // Collapse loops. Store the insertion point because LoopInfos may get
797   // invalidated.
798   llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
799   llvm::CanonicalLoopInfo *loopInfo =
800       ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
801 
802   allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
803 
804   bool isSimd = loop.simd_modifier();
805 
806   // The orderedVal refers to the value obtained from the ordered[(n)] clause.
807   //   orderedVal == -1: No ordered[(n)] clause specified.
808   //   orderedVal == 0: The ordered clause specified without a parameter.
809   //   orderedVal > 0: The ordered clause specified with a parameter (n).
810   // TODO: Handle doacross loop init when orderedVal is greater than 0.
811   int64_t orderedVal =
812       loop.ordered_val().hasValue() ? loop.ordered_val().getValue() : -1;
813   if (schedule == omp::ClauseScheduleKind::Static && orderedVal != 0) {
814     ompBuilder->applyWorkshareLoop(ompLoc.DL, loopInfo, allocaIP,
815                                    !loop.nowait(),
816                                    llvm::omp::OMP_SCHEDULE_Static, chunk);
817   } else {
818     llvm::omp::OMPScheduleType schedType;
819     switch (schedule) {
820     case omp::ClauseScheduleKind::Static:
821       if (loop.schedule_chunk_var())
822         schedType = llvm::omp::OMPScheduleType::OrderedStaticChunked;
823       else
824         schedType = llvm::omp::OMPScheduleType::OrderedStatic;
825       break;
826     case omp::ClauseScheduleKind::Dynamic:
827       if (orderedVal == 0)
828         schedType = llvm::omp::OMPScheduleType::OrderedDynamicChunked;
829       else
830         schedType = llvm::omp::OMPScheduleType::DynamicChunked;
831       break;
832     case omp::ClauseScheduleKind::Guided:
833       if (orderedVal == 0) {
834         schedType = llvm::omp::OMPScheduleType::OrderedGuidedChunked;
835       } else {
836         if (isSimd)
837           schedType = llvm::omp::OMPScheduleType::GuidedSimd;
838         else
839           schedType = llvm::omp::OMPScheduleType::GuidedChunked;
840       }
841       break;
842     case omp::ClauseScheduleKind::Auto:
843       if (orderedVal == 0)
844         schedType = llvm::omp::OMPScheduleType::OrderedAuto;
845       else
846         schedType = llvm::omp::OMPScheduleType::Auto;
847       break;
848     case omp::ClauseScheduleKind::Runtime:
849       if (orderedVal == 0) {
850         schedType = llvm::omp::OMPScheduleType::OrderedRuntime;
851       } else {
852         if (isSimd)
853           schedType = llvm::omp::OMPScheduleType::RuntimeSimd;
854         else
855           schedType = llvm::omp::OMPScheduleType::Runtime;
856       }
857       break;
858     }
859 
860     if (Optional<omp::ScheduleModifier> modifier = loop.schedule_modifier()) {
861       switch (*modifier) {
862       case omp::ScheduleModifier::monotonic:
863         schedType |= llvm::omp::OMPScheduleType::ModifierMonotonic;
864         break;
865       case omp::ScheduleModifier::nonmonotonic:
866         schedType |= llvm::omp::OMPScheduleType::ModifierNonmonotonic;
867         break;
868       default:
869         // Nothing to do here.
870         break;
871       }
872     } else {
873       // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
874       // If the static schedule kind is specified or if the ordered clause is
875       // specified, and if the nonmonotonic modifier is not specified, the
876       // effect is as if the monotonic modifier is specified. Otherwise, unless
877       // the monotonic modifier is specified, the effect is as if the
878       // nonmonotonic modifier is specified.
879       // The monotonic is used by default in openmp runtime library, so no need
880       // to set it.
881       if (!(schedType == llvm::omp::OMPScheduleType::OrderedStatic ||
882             schedType == llvm::omp::OMPScheduleType::OrderedStaticChunked))
883         schedType |= llvm::omp::OMPScheduleType::ModifierNonmonotonic;
884     }
885 
886     ompBuilder->applyDynamicWorkshareLoop(ompLoc.DL, loopInfo, allocaIP,
887                                           schedType, !loop.nowait(), chunk,
888                                           /*ordered*/ orderedVal == 0);
889   }
890 
891   // Continue building IR after the loop. Note that the LoopInfo returned by
892   // `collapseLoops` points inside the outermost loop and is intended for
893   // potential further loop transformations. Use the insertion point stored
894   // before collapsing loops instead.
895   builder.restoreIP(afterIP);
896 
897   // Process the reductions if required.
898   if (numReductions == 0)
899     return success();
900 
901   // Create the reduction generators. We need to own them here because
902   // ReductionInfo only accepts references to the generators.
903   SmallVector<OwningReductionGen> owningReductionGens;
904   SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
905   for (unsigned i = 0; i < numReductions; ++i) {
906     owningReductionGens.push_back(
907         makeReductionGen(reductionDecls[i], builder, moduleTranslation));
908     owningAtomicReductionGens.push_back(
909         makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
910   }
911 
912   // Collect the reduction information.
913   SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
914   reductionInfos.reserve(numReductions);
915   for (unsigned i = 0; i < numReductions; ++i) {
916     llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
917     if (owningAtomicReductionGens[i])
918       atomicGen = owningAtomicReductionGens[i];
919     auto reductionType =
920         loop.reduction_vars()[i].getType().cast<LLVM::LLVMPointerType>();
921     llvm::Value *variable =
922         moduleTranslation.lookupValue(loop.reduction_vars()[i]);
923     reductionInfos.push_back(
924         {moduleTranslation.convertType(reductionType.getElementType()),
925          variable, privateReductionVariables[i], owningReductionGens[i],
926          atomicGen});
927   }
928 
929   // The call to createReductions below expects the block to have a
930   // terminator. Create an unreachable instruction to serve as terminator
931   // and remove it later.
932   llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
933   builder.SetInsertPoint(tempTerminator);
934   llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
935       ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
936                                    loop.nowait());
937   if (!contInsertPoint.getBlock())
938     return loop->emitOpError() << "failed to convert reductions";
939   auto nextInsertionPoint =
940       ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
941   tempTerminator->eraseFromParent();
942   builder.restoreIP(nextInsertionPoint);
943 
944   return success();
945 }
946 
947 /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
948 static LogicalResult
949 convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,
950                    LLVM::ModuleTranslation &moduleTranslation) {
951   auto loop = cast<omp::SimdLoopOp>(opInst);
952 
953   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
954 
955   // Generator of the canonical loop body.
956   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
957   // relying on captured variables.
958   SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
959   SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
960   LogicalResult bodyGenStatus = success();
961   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
962     // Make sure further conversions know about the induction variable.
963     moduleTranslation.mapValue(
964         loop.getRegion().front().getArgument(loopInfos.size()), iv);
965 
966     // Capture the body insertion point for use in nested loops. BodyIP of the
967     // CanonicalLoopInfo always points to the beginning of the entry block of
968     // the body.
969     bodyInsertPoints.push_back(ip);
970 
971     if (loopInfos.size() != loop.getNumLoops() - 1)
972       return;
973 
974     // Convert the body of the loop.
975     llvm::BasicBlock *entryBlock = ip.getBlock();
976     llvm::BasicBlock *exitBlock =
977         entryBlock->splitBasicBlock(ip.getPoint(), "omp.simdloop.exit");
978     convertOmpOpRegions(loop.region(), "omp.simdloop.region", *entryBlock,
979                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
980   };
981 
982   // Delegate actual loop construction to the OpenMP IRBuilder.
983   // TODO: this currently assumes SimdLoop is semantically similar to SCF loop,
984   // i.e. it has a positive step, uses signed integer semantics. Reconsider
985   // this code when SimdLoop clearly supports more cases.
986   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
987   for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
988     llvm::Value *lowerBound =
989         moduleTranslation.lookupValue(loop.lowerBound()[i]);
990     llvm::Value *upperBound =
991         moduleTranslation.lookupValue(loop.upperBound()[i]);
992     llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
993 
994     // Make sure loop trip count are emitted in the preheader of the outermost
995     // loop at the latest so that they are all available for the new collapsed
996     // loop will be created below.
997     llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
998     llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
999     if (i != 0) {
1000       loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
1001                                                        ompLoc.DL);
1002       computeIP = loopInfos.front()->getPreheaderIP();
1003     }
1004     loopInfos.push_back(ompBuilder->createCanonicalLoop(
1005         loc, bodyGen, lowerBound, upperBound, step,
1006         /*IsSigned=*/true, /*Inclusive=*/true, computeIP));
1007 
1008     if (failed(bodyGenStatus))
1009       return failure();
1010   }
1011 
1012   // Collapse loops.
1013   llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
1014   llvm::CanonicalLoopInfo *loopInfo =
1015       ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
1016 
1017   ompBuilder->applySimd(ompLoc.DL, loopInfo);
1018 
1019   builder.restoreIP(afterIP);
1020   return success();
1021 }
1022 
1023 /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
1024 llvm::AtomicOrdering
1025 convertAtomicOrdering(Optional<omp::ClauseMemoryOrderKind> ao) {
1026   if (!ao)
1027     return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
1028 
1029   switch (*ao) {
1030   case omp::ClauseMemoryOrderKind::Seq_cst:
1031     return llvm::AtomicOrdering::SequentiallyConsistent;
1032   case omp::ClauseMemoryOrderKind::Acq_rel:
1033     return llvm::AtomicOrdering::AcquireRelease;
1034   case omp::ClauseMemoryOrderKind::Acquire:
1035     return llvm::AtomicOrdering::Acquire;
1036   case omp::ClauseMemoryOrderKind::Release:
1037     return llvm::AtomicOrdering::Release;
1038   case omp::ClauseMemoryOrderKind::Relaxed:
1039     return llvm::AtomicOrdering::Monotonic;
1040   }
1041   llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
1042 }
1043 
1044 /// Convert omp.atomic.read operation to LLVM IR.
1045 static LogicalResult
1046 convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
1047                      LLVM::ModuleTranslation &moduleTranslation) {
1048 
1049   auto readOp = cast<omp::AtomicReadOp>(opInst);
1050   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1051 
1052   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1053 
1054   llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.memory_order_val());
1055   llvm::Value *x = moduleTranslation.lookupValue(readOp.x());
1056   Type xTy = readOp.x().getType().cast<omp::PointerLikeType>().getElementType();
1057   llvm::Value *v = moduleTranslation.lookupValue(readOp.v());
1058   Type vTy = readOp.v().getType().cast<omp::PointerLikeType>().getElementType();
1059   llvm::OpenMPIRBuilder::AtomicOpValue V = {
1060       v, moduleTranslation.convertType(vTy), false, false};
1061   llvm::OpenMPIRBuilder::AtomicOpValue X = {
1062       x, moduleTranslation.convertType(xTy), false, false};
1063   builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO));
1064   return success();
1065 }
1066 
1067 /// Converts an omp.atomic.write operation to LLVM IR.
1068 static LogicalResult
1069 convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
1070                       LLVM::ModuleTranslation &moduleTranslation) {
1071   auto writeOp = cast<omp::AtomicWriteOp>(opInst);
1072   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1073 
1074   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1075   llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.memory_order_val());
1076   llvm::Value *expr = moduleTranslation.lookupValue(writeOp.value());
1077   llvm::Value *dest = moduleTranslation.lookupValue(writeOp.address());
1078   llvm::Type *ty = moduleTranslation.convertType(writeOp.value().getType());
1079   llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
1080                                             /*isVolatile=*/false};
1081   builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
1082   return success();
1083 }
1084 
1085 /// Converts an LLVM dialect binary operation to the corresponding enum value
1086 /// for `atomicrmw` supported binary operation.
1087 llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
1088   return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op)
1089       .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
1090       .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
1091       .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
1092       .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
1093       .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
1094       .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
1095       .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
1096       .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
1097       .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
1098       .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
1099 }
1100 
1101 /// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
1102 static LogicalResult
1103 convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
1104                        llvm::IRBuilderBase &builder,
1105                        LLVM::ModuleTranslation &moduleTranslation) {
1106   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1107   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1108 
1109   // Convert values and types.
1110   auto &innerOpList = opInst.region().front().getOperations();
1111   if (innerOpList.size() != 2)
1112     return opInst.emitError("exactly two operations are allowed inside an "
1113                             "atomic update region while lowering to LLVM IR");
1114 
1115   Operation &innerUpdateOp = innerOpList.front();
1116 
1117   if (innerUpdateOp.getNumOperands() != 2 ||
1118       !llvm::is_contained(innerUpdateOp.getOperands(),
1119                           opInst.getRegion().getArgument(0)))
1120     return opInst.emitError(
1121         "the update operation inside the region must be a binary operation and "
1122         "that update operation must have the region argument as an operand");
1123 
1124   llvm::AtomicRMWInst::BinOp binop = convertBinOpToAtomic(innerUpdateOp);
1125 
1126   bool isXBinopExpr =
1127       innerUpdateOp.getNumOperands() > 0 &&
1128       innerUpdateOp.getOperand(0) == opInst.getRegion().getArgument(0);
1129 
1130   mlir::Value mlirExpr = (isXBinopExpr ? innerUpdateOp.getOperand(1)
1131                                        : innerUpdateOp.getOperand(0));
1132   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
1133   llvm::Value *llvmX = moduleTranslation.lookupValue(opInst.x());
1134   LLVM::LLVMPointerType mlirXType =
1135       opInst.x().getType().cast<LLVM::LLVMPointerType>();
1136   llvm::Type *llvmXElementType =
1137       moduleTranslation.convertType(mlirXType.getElementType());
1138   llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
1139                                                       /*isSigned=*/false,
1140                                                       /*isVolatile=*/false};
1141 
1142   llvm::AtomicOrdering atomicOrdering =
1143       convertAtomicOrdering(opInst.memory_order_val());
1144 
1145   // Generate update code.
1146   LogicalResult updateGenStatus = success();
1147   auto updateFn = [&opInst, &moduleTranslation, &updateGenStatus](
1148                       llvm::Value *atomicx,
1149                       llvm::IRBuilder<> &builder) -> llvm::Value * {
1150     Block &bb = *opInst.region().begin();
1151     moduleTranslation.mapValue(*opInst.region().args_begin(), atomicx);
1152     moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
1153     if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
1154       updateGenStatus = (opInst.emitError()
1155                          << "unable to convert update operation to llvm IR");
1156       return nullptr;
1157     }
1158     omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
1159     assert(yieldop && yieldop.results().size() == 1 &&
1160            "terminator must be omp.yield op and it must have exactly one "
1161            "argument");
1162     return moduleTranslation.lookupValue(yieldop.results()[0]);
1163   };
1164 
1165   // Handle ambiguous alloca, if any.
1166   auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1167   if (allocaIP.getPoint() == ompLoc.IP.getPoint()) {
1168     // Same point => split basic block and make them unambigous.
1169     llvm::UnreachableInst *unreachableInst = builder.CreateUnreachable();
1170     builder.SetInsertPoint(builder.GetInsertBlock()->splitBasicBlock(
1171         unreachableInst, "alloca_split"));
1172     ompLoc.IP = builder.saveIP();
1173     unreachableInst->eraseFromParent();
1174   }
1175   builder.restoreIP(ompBuilder->createAtomicUpdate(
1176       ompLoc, findAllocaInsertPoint(builder, moduleTranslation), llvmAtomicX,
1177       llvmExpr, atomicOrdering, binop, updateFn, isXBinopExpr));
1178   return updateGenStatus;
1179 }
1180 
1181 static LogicalResult
1182 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
1183                         llvm::IRBuilderBase &builder,
1184                         LLVM::ModuleTranslation &moduleTranslation) {
1185   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1186   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1187   mlir::Value mlirExpr;
1188   bool isXBinopExpr = false, isPostfixUpdate = false;
1189   llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
1190 
1191   omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
1192   omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
1193 
1194   assert((atomicUpdateOp || atomicWriteOp) &&
1195          "internal op must be an atomic.update or atomic.write op");
1196 
1197   if (atomicWriteOp) {
1198     isPostfixUpdate = true;
1199     mlirExpr = atomicWriteOp.value();
1200   } else {
1201     isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
1202                       atomicCaptureOp.getAtomicUpdateOp().getOperation();
1203     auto &innerOpList = atomicUpdateOp.region().front().getOperations();
1204     if (innerOpList.size() != 2)
1205       return atomicUpdateOp.emitError(
1206           "exactly two operations are allowed inside an "
1207           "atomic update region while lowering to LLVM IR");
1208     Operation *innerUpdateOp = atomicUpdateOp.getFirstOp();
1209     if (innerUpdateOp->getNumOperands() != 2 ||
1210         !llvm::is_contained(innerUpdateOp->getOperands(),
1211                             atomicUpdateOp.getRegion().getArgument(0)))
1212       return atomicUpdateOp.emitError(
1213           "the update operation inside the region must be a binary operation "
1214           "and that update operation must have the region argument as an "
1215           "operand");
1216     binop = convertBinOpToAtomic(*innerUpdateOp);
1217 
1218     isXBinopExpr = innerUpdateOp->getOperand(0) ==
1219                    atomicUpdateOp.getRegion().getArgument(0);
1220 
1221     mlirExpr = (isXBinopExpr ? innerUpdateOp->getOperand(1)
1222                              : innerUpdateOp->getOperand(0));
1223   }
1224 
1225   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
1226   llvm::Value *llvmX =
1227       moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().x());
1228   llvm::Value *llvmV =
1229       moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().v());
1230   auto mlirXType = atomicCaptureOp.getAtomicReadOp()
1231                        .x()
1232                        .getType()
1233                        .cast<LLVM::LLVMPointerType>();
1234   llvm::Type *llvmXElementType =
1235       moduleTranslation.convertType(mlirXType.getElementType());
1236   llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
1237                                                       /*isSigned=*/false,
1238                                                       /*isVolatile=*/false};
1239   llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
1240                                                       /*isSigned=*/false,
1241                                                       /*isVolatile=*/false};
1242 
1243   llvm::AtomicOrdering atomicOrdering =
1244       convertAtomicOrdering(atomicCaptureOp.memory_order_val());
1245 
1246   LogicalResult updateGenStatus = success();
1247   auto updateFn = [&](llvm::Value *atomicx,
1248                       llvm::IRBuilder<> &builder) -> llvm::Value * {
1249     if (atomicWriteOp)
1250       return moduleTranslation.lookupValue(atomicWriteOp.value());
1251     Block &bb = *atomicUpdateOp.region().begin();
1252     moduleTranslation.mapValue(*atomicUpdateOp.region().args_begin(), atomicx);
1253     moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
1254     if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
1255       updateGenStatus = (atomicUpdateOp.emitError()
1256                          << "unable to convert update operation to llvm IR");
1257       return nullptr;
1258     }
1259     omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
1260     assert(yieldop && yieldop.results().size() == 1 &&
1261            "terminator must be omp.yield op and it must have exactly one "
1262            "argument");
1263     return moduleTranslation.lookupValue(yieldop.results()[0]);
1264   };
1265   // Handle ambiguous alloca, if any.
1266   auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1267   if (allocaIP.getPoint() == ompLoc.IP.getPoint()) {
1268     // Same point => split basic block and make them unambigous.
1269     llvm::UnreachableInst *unreachableInst = builder.CreateUnreachable();
1270     builder.SetInsertPoint(builder.GetInsertBlock()->splitBasicBlock(
1271         unreachableInst, "alloca_split"));
1272     ompLoc.IP = builder.saveIP();
1273     unreachableInst->eraseFromParent();
1274   }
1275   builder.restoreIP(ompBuilder->createAtomicCapture(
1276       ompLoc, findAllocaInsertPoint(builder, moduleTranslation), llvmAtomicX,
1277       llvmAtomicV, llvmExpr, atomicOrdering, binop, updateFn, atomicUpdateOp,
1278       isPostfixUpdate, isXBinopExpr));
1279   return updateGenStatus;
1280 }
1281 
1282 /// Converts an OpenMP reduction operation using OpenMPIRBuilder. Expects the
1283 /// mapping between reduction variables and their private equivalents to have
1284 /// been stored on the ModuleTranslation stack. Currently only supports
1285 /// reduction within WsLoopOp, but can be easily extended.
1286 static LogicalResult
1287 convertOmpReductionOp(omp::ReductionOp reductionOp,
1288                       llvm::IRBuilderBase &builder,
1289                       LLVM::ModuleTranslation &moduleTranslation) {
1290   // Find the declaration that corresponds to the reduction op.
1291   auto reductionContainer = reductionOp->getParentOfType<omp::WsLoopOp>();
1292   omp::ReductionDeclareOp declaration =
1293       findReductionDecl(reductionContainer, reductionOp);
1294   assert(declaration && "could not find reduction declaration");
1295 
1296   // Retrieve the mapping between reduction variables and their private
1297   // equivalents.
1298   const DenseMap<Value, llvm::Value *> *reductionVariableMap = nullptr;
1299   moduleTranslation.stackWalk<OpenMPVarMappingStackFrame>(
1300       [&](const OpenMPVarMappingStackFrame &frame) {
1301         reductionVariableMap = &frame.mapping;
1302         return WalkResult::interrupt();
1303       });
1304   assert(reductionVariableMap && "couldn't find private reduction variables");
1305 
1306   // Translate the reduction operation by emitting the body of the corresponding
1307   // reduction declaration.
1308   Region &reductionRegion = declaration.reductionRegion();
1309   llvm::Value *privateReductionVar =
1310       reductionVariableMap->lookup(reductionOp.accumulator());
1311   llvm::Value *reductionVal = builder.CreateLoad(
1312       moduleTranslation.convertType(reductionOp.operand().getType()),
1313       privateReductionVar);
1314 
1315   moduleTranslation.mapValue(reductionRegion.front().getArgument(0),
1316                              reductionVal);
1317   moduleTranslation.mapValue(
1318       reductionRegion.front().getArgument(1),
1319       moduleTranslation.lookupValue(reductionOp.operand()));
1320 
1321   SmallVector<llvm::Value *> phis;
1322   if (failed(inlineConvertOmpRegions(reductionRegion, "omp.reduction.body",
1323                                      builder, moduleTranslation, &phis)))
1324     return failure();
1325   assert(phis.size() == 1 && "expected one value to be yielded from "
1326                              "the reduction body declaration region");
1327   builder.CreateStore(phis[0], privateReductionVar);
1328   return success();
1329 }
1330 
1331 namespace {
1332 
1333 /// Implementation of the dialect interface that converts operations belonging
1334 /// to the OpenMP dialect to LLVM IR.
1335 class OpenMPDialectLLVMIRTranslationInterface
1336     : public LLVMTranslationDialectInterface {
1337 public:
1338   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
1339 
1340   /// Translates the given operation to LLVM IR using the provided IR builder
1341   /// and saving the state in `moduleTranslation`.
1342   LogicalResult
1343   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
1344                    LLVM::ModuleTranslation &moduleTranslation) const final;
1345 };
1346 
1347 } // namespace
1348 
1349 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
1350 /// (including OpenMP runtime calls).
1351 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
1352     Operation *op, llvm::IRBuilderBase &builder,
1353     LLVM::ModuleTranslation &moduleTranslation) const {
1354 
1355   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1356 
1357   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
1358       .Case([&](omp::BarrierOp) {
1359         ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1360         return success();
1361       })
1362       .Case([&](omp::TaskwaitOp) {
1363         ompBuilder->createTaskwait(builder.saveIP());
1364         return success();
1365       })
1366       .Case([&](omp::TaskyieldOp) {
1367         ompBuilder->createTaskyield(builder.saveIP());
1368         return success();
1369       })
1370       .Case([&](omp::FlushOp) {
1371         // No support in Openmp runtime function (__kmpc_flush) to accept
1372         // the argument list.
1373         // OpenMP standard states the following:
1374         //  "An implementation may implement a flush with a list by ignoring
1375         //   the list, and treating it the same as a flush without a list."
1376         //
1377         // The argument list is discarded so that, flush with a list is treated
1378         // same as a flush without a list.
1379         ompBuilder->createFlush(builder.saveIP());
1380         return success();
1381       })
1382       .Case([&](omp::ParallelOp op) {
1383         return convertOmpParallel(op, builder, moduleTranslation);
1384       })
1385       .Case([&](omp::ReductionOp reductionOp) {
1386         return convertOmpReductionOp(reductionOp, builder, moduleTranslation);
1387       })
1388       .Case([&](omp::MasterOp) {
1389         return convertOmpMaster(*op, builder, moduleTranslation);
1390       })
1391       .Case([&](omp::CriticalOp) {
1392         return convertOmpCritical(*op, builder, moduleTranslation);
1393       })
1394       .Case([&](omp::OrderedRegionOp) {
1395         return convertOmpOrderedRegion(*op, builder, moduleTranslation);
1396       })
1397       .Case([&](omp::OrderedOp) {
1398         return convertOmpOrdered(*op, builder, moduleTranslation);
1399       })
1400       .Case([&](omp::WsLoopOp) {
1401         return convertOmpWsLoop(*op, builder, moduleTranslation);
1402       })
1403       .Case([&](omp::SimdLoopOp) {
1404         return convertOmpSimdLoop(*op, builder, moduleTranslation);
1405       })
1406       .Case([&](omp::AtomicReadOp) {
1407         return convertOmpAtomicRead(*op, builder, moduleTranslation);
1408       })
1409       .Case([&](omp::AtomicWriteOp) {
1410         return convertOmpAtomicWrite(*op, builder, moduleTranslation);
1411       })
1412       .Case([&](omp::AtomicUpdateOp op) {
1413         return convertOmpAtomicUpdate(op, builder, moduleTranslation);
1414       })
1415       .Case([&](omp::AtomicCaptureOp op) {
1416         return convertOmpAtomicCapture(op, builder, moduleTranslation);
1417       })
1418       .Case([&](omp::SectionsOp) {
1419         return convertOmpSections(*op, builder, moduleTranslation);
1420       })
1421       .Case([&](omp::SingleOp op) {
1422         return convertOmpSingle(op, builder, moduleTranslation);
1423       })
1424       .Case<omp::YieldOp, omp::TerminatorOp, omp::ReductionDeclareOp,
1425             omp::CriticalDeclareOp>([](auto op) {
1426         // `yield` and `terminator` can be just omitted. The block structure
1427         // was created in the region that handles their parent operation.
1428         // `reduction.declare` will be used by reductions and is not
1429         // converted directly, skip it.
1430         // `critical.declare` is only used to declare names of critical
1431         // sections which will be used by `critical` ops and hence can be
1432         // ignored for lowering. The OpenMP IRBuilder will create unique
1433         // name for critical section names.
1434         return success();
1435       })
1436       .Default([&](Operation *inst) {
1437         return inst->emitError("unsupported OpenMP operation: ")
1438                << inst->getName();
1439       });
1440 }
1441 
1442 void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
1443   registry.insert<omp::OpenMPDialect>();
1444   registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
1445     dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
1446   });
1447 }
1448 
1449 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
1450   DialectRegistry registry;
1451   registerOpenMPDialectTranslation(registry);
1452   context.appendDialectRegistry(registry);
1453 }
1454