1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "mlir/IR/BlockAndValueMapping.h"
16 #include "mlir/IR/Operation.h"
17 #include "mlir/Support/LLVM.h"
18 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
19 
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
23 #include "llvm/IR/DebugInfoMetadata.h"
24 #include "llvm/IR/IRBuilder.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 static llvm::omp::ScheduleKind
30 convertToScheduleKind(Optional<omp::ClauseScheduleKind> schedKind) {
31   if (!schedKind.hasValue())
32     return llvm::omp::OMP_SCHEDULE_Default;
33   switch (schedKind.getValue()) {
34   case omp::ClauseScheduleKind::Static:
35     return llvm::omp::OMP_SCHEDULE_Static;
36   case omp::ClauseScheduleKind::Dynamic:
37     return llvm::omp::OMP_SCHEDULE_Dynamic;
38   case omp::ClauseScheduleKind::Guided:
39     return llvm::omp::OMP_SCHEDULE_Guided;
40   case omp::ClauseScheduleKind::Auto:
41     return llvm::omp::OMP_SCHEDULE_Auto;
42   case omp::ClauseScheduleKind::Runtime:
43     return llvm::omp::OMP_SCHEDULE_Runtime;
44   }
45   llvm_unreachable("unhandled schedule clause argument");
46 }
47 
48 /// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
49 /// insertion points for allocas.
50 class OpenMPAllocaStackFrame
51     : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
52 public:
53   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
54 
55   explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
56       : allocaInsertPoint(allocaIP) {}
57   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
58 };
59 
60 /// ModuleTranslation stack frame containing the partial mapping between MLIR
61 /// values and their LLVM IR equivalents.
62 class OpenMPVarMappingStackFrame
63     : public LLVM::ModuleTranslation::StackFrameBase<
64           OpenMPVarMappingStackFrame> {
65 public:
66   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPVarMappingStackFrame)
67 
68   explicit OpenMPVarMappingStackFrame(
69       const DenseMap<Value, llvm::Value *> &mapping)
70       : mapping(mapping) {}
71 
72   DenseMap<Value, llvm::Value *> mapping;
73 };
74 } // namespace
75 
76 /// Find the insertion point for allocas given the current insertion point for
77 /// normal operations in the builder.
78 static llvm::OpenMPIRBuilder::InsertPointTy
79 findAllocaInsertPoint(llvm::IRBuilderBase &builder,
80                       const LLVM::ModuleTranslation &moduleTranslation) {
81   // If there is an alloca insertion point on stack, i.e. we are in a nested
82   // operation and a specific point was provided by some surrounding operation,
83   // use it.
84   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
85   WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
86       [&](const OpenMPAllocaStackFrame &frame) {
87         allocaInsertPoint = frame.allocaInsertPoint;
88         return WalkResult::interrupt();
89       });
90   if (walkResult.wasInterrupted())
91     return allocaInsertPoint;
92 
93   // Otherwise, insert to the entry block of the surrounding function.
94   // If the current IRBuilder InsertPoint is the function's entry, it cannot
95   // also be used for alloca insertion which would result in insertion order
96   // confusion. Create a new BasicBlock for the Builder and use the entry block
97   // for the allocs.
98   if (builder.GetInsertBlock() ==
99       &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
100     assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
101            "Assuming end of basic block");
102     llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
103         builder.getContext(), "entry", builder.GetInsertBlock()->getParent(),
104         builder.GetInsertBlock()->getNextNode());
105     builder.CreateBr(entryBB);
106     builder.SetInsertPoint(entryBB);
107   }
108 
109   llvm::BasicBlock &funcEntryBlock =
110       builder.GetInsertBlock()->getParent()->getEntryBlock();
111   return llvm::OpenMPIRBuilder::InsertPointTy(
112       &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
113 }
114 
115 /// Converts the given region that appears within an OpenMP dialect operation to
116 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
117 /// region, and a branch from any block with an successor-less OpenMP terminator
118 /// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
119 /// of the continuation block if provided.
120 static void convertOmpOpRegions(
121     Region &region, StringRef blockName, llvm::BasicBlock &sourceBlock,
122     llvm::BasicBlock &continuationBlock, llvm::IRBuilderBase &builder,
123     LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus,
124     SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
125   llvm::LLVMContext &llvmContext = builder.getContext();
126   for (Block &bb : region) {
127     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
128         llvmContext, blockName, builder.GetInsertBlock()->getParent(),
129         builder.GetInsertBlock()->getNextNode());
130     moduleTranslation.mapBlock(&bb, llvmBB);
131   }
132 
133   llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
134 
135   // Terminators (namely YieldOp) may be forwarding values to the region that
136   // need to be available in the continuation block. Collect the types of these
137   // operands in preparation of creating PHI nodes.
138   SmallVector<llvm::Type *> continuationBlockPHITypes;
139   bool operandsProcessed = false;
140   unsigned numYields = 0;
141   for (Block &bb : region.getBlocks()) {
142     if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
143       if (!operandsProcessed) {
144         for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
145           continuationBlockPHITypes.push_back(
146               moduleTranslation.convertType(yield->getOperand(i).getType()));
147         }
148         operandsProcessed = true;
149       } else {
150         assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
151                "mismatching number of values yielded from the region");
152         for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
153           llvm::Type *operandType =
154               moduleTranslation.convertType(yield->getOperand(i).getType());
155           (void)operandType;
156           assert(continuationBlockPHITypes[i] == operandType &&
157                  "values of mismatching types yielded from the region");
158         }
159       }
160       numYields++;
161     }
162   }
163 
164   // Insert PHI nodes in the continuation block for any values forwarded by the
165   // terminators in this region.
166   if (!continuationBlockPHITypes.empty())
167     assert(
168         continuationBlockPHIs &&
169         "expected continuation block PHIs if converted regions yield values");
170   if (continuationBlockPHIs) {
171     llvm::IRBuilderBase::InsertPointGuard guard(builder);
172     continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
173     builder.SetInsertPoint(&continuationBlock, continuationBlock.begin());
174     for (llvm::Type *ty : continuationBlockPHITypes)
175       continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
176   }
177 
178   // Convert blocks one by one in topological order to ensure
179   // defs are converted before uses.
180   SetVector<Block *> blocks =
181       LLVM::detail::getTopologicallySortedBlocks(region);
182   for (Block *bb : blocks) {
183     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
184     // Retarget the branch of the entry block to the entry block of the
185     // converted region (regions are single-entry).
186     if (bb->isEntryBlock()) {
187       assert(sourceTerminator->getNumSuccessors() == 1 &&
188              "provided entry block has multiple successors");
189       assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
190              "ContinuationBlock is not the successor of the entry block");
191       sourceTerminator->setSuccessor(0, llvmBB);
192     }
193 
194     llvm::IRBuilderBase::InsertPointGuard guard(builder);
195     if (failed(
196             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
197       bodyGenStatus = failure();
198       return;
199     }
200 
201     // Special handling for `omp.yield` and `omp.terminator` (we may have more
202     // than one): they return the control to the parent OpenMP dialect operation
203     // so replace them with the branch to the continuation block. We handle this
204     // here to avoid relying inter-function communication through the
205     // ModuleTranslation class to set up the correct insertion point. This is
206     // also consistent with MLIR's idiom of handling special region terminators
207     // in the same code that handles the region-owning operation.
208     Operation *terminator = bb->getTerminator();
209     if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
210       builder.CreateBr(&continuationBlock);
211 
212       for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
213         (*continuationBlockPHIs)[i]->addIncoming(
214             moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
215     }
216   }
217   // After all blocks have been traversed and values mapped, connect the PHI
218   // nodes to the results of preceding blocks.
219   LLVM::detail::connectPHINodes(region, moduleTranslation);
220 
221   // Remove the blocks and values defined in this region from the mapping since
222   // they are not visible outside of this region. This allows the same region to
223   // be converted several times, that is cloned, without clashes, and slightly
224   // speeds up the lookups.
225   moduleTranslation.forgetMapping(region);
226 }
227 
228 /// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
229 static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
230   switch (kind) {
231   case omp::ClauseProcBindKind::Close:
232     return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
233   case omp::ClauseProcBindKind::Master:
234     return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
235   case omp::ClauseProcBindKind::Primary:
236     return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
237   case omp::ClauseProcBindKind::Spread:
238     return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
239   }
240   llvm_unreachable("Unknown ClauseProcBindKind kind");
241 }
242 
243 /// Converts the OpenMP parallel operation to LLVM IR.
244 static LogicalResult
245 convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
246                    LLVM::ModuleTranslation &moduleTranslation) {
247   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
248   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
249   // relying on captured variables.
250   LogicalResult bodyGenStatus = success();
251 
252   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
253                        llvm::BasicBlock &continuationBlock) {
254     // Save the alloca insertion point on ModuleTranslation stack for use in
255     // nested regions.
256     LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
257         moduleTranslation, allocaIP);
258 
259     // ParallelOp has only one region associated with it.
260     convertOmpOpRegions(opInst.getRegion(), "omp.par.region",
261                         *codeGenIP.getBlock(), continuationBlock, builder,
262                         moduleTranslation, bodyGenStatus);
263   };
264 
265   // TODO: Perform appropriate actions according to the data-sharing
266   // attribute (shared, private, firstprivate, ...) of variables.
267   // Currently defaults to shared.
268   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
269                     llvm::Value &, llvm::Value &vPtr,
270                     llvm::Value *&replacementValue) -> InsertPointTy {
271     replacementValue = &vPtr;
272 
273     return codeGenIP;
274   };
275 
276   // TODO: Perform finalization actions for variables. This has to be
277   // called for variables which have destructors/finalizers.
278   auto finiCB = [&](InsertPointTy codeGenIP) {};
279 
280   llvm::Value *ifCond = nullptr;
281   if (auto ifExprVar = opInst.if_expr_var())
282     ifCond = moduleTranslation.lookupValue(ifExprVar);
283   llvm::Value *numThreads = nullptr;
284   if (auto numThreadsVar = opInst.num_threads_var())
285     numThreads = moduleTranslation.lookupValue(numThreadsVar);
286   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
287   if (auto bind = opInst.proc_bind_val())
288     pbKind = getProcBindKind(*bind);
289   // TODO: Is the Parallel construct cancellable?
290   bool isCancellable = false;
291 
292   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
293       findAllocaInsertPoint(builder, moduleTranslation);
294   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
295   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
296       ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
297       isCancellable));
298 
299   return bodyGenStatus;
300 }
301 
302 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
303 static LogicalResult
304 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
305                  LLVM::ModuleTranslation &moduleTranslation) {
306   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
307   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
308   // relying on captured variables.
309   LogicalResult bodyGenStatus = success();
310 
311   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
312                        llvm::BasicBlock &continuationBlock) {
313     // MasterOp has only one region associated with it.
314     auto &region = cast<omp::MasterOp>(opInst).getRegion();
315     convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
316                         continuationBlock, builder, moduleTranslation,
317                         bodyGenStatus);
318   };
319 
320   // TODO: Perform finalization actions for variables. This has to be
321   // called for variables which have destructors/finalizers.
322   auto finiCB = [&](InsertPointTy codeGenIP) {};
323 
324   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
325   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
326       ompLoc, bodyGenCB, finiCB));
327   return success();
328 }
329 
330 /// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
331 static LogicalResult
332 convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
333                    LLVM::ModuleTranslation &moduleTranslation) {
334   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
335   auto criticalOp = cast<omp::CriticalOp>(opInst);
336   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
337   // relying on captured variables.
338   LogicalResult bodyGenStatus = success();
339 
340   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
341                        llvm::BasicBlock &continuationBlock) {
342     // CriticalOp has only one region associated with it.
343     auto &region = cast<omp::CriticalOp>(opInst).getRegion();
344     convertOmpOpRegions(region, "omp.critical.region", *codeGenIP.getBlock(),
345                         continuationBlock, builder, moduleTranslation,
346                         bodyGenStatus);
347   };
348 
349   // TODO: Perform finalization actions for variables. This has to be
350   // called for variables which have destructors/finalizers.
351   auto finiCB = [&](InsertPointTy codeGenIP) {};
352 
353   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
354   llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
355   llvm::Constant *hint = nullptr;
356 
357   // If it has a name, it probably has a hint too.
358   if (criticalOp.nameAttr()) {
359     // The verifiers in OpenMP Dialect guarentee that all the pointers are
360     // non-null
361     auto symbolRef = criticalOp.nameAttr().cast<SymbolRefAttr>();
362     auto criticalDeclareOp =
363         SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
364                                                                      symbolRef);
365     hint =
366         llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
367                                static_cast<int>(criticalDeclareOp.hint_val()));
368   }
369   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
370       ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
371   return success();
372 }
373 
374 /// Returns a reduction declaration that corresponds to the given reduction
375 /// operation in the given container. Currently only supports reductions inside
376 /// WsLoopOp but can be easily extended.
377 static omp::ReductionDeclareOp findReductionDecl(omp::WsLoopOp container,
378                                                  omp::ReductionOp reduction) {
379   SymbolRefAttr reductionSymbol;
380   for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) {
381     if (container.reduction_vars()[i] != reduction.accumulator())
382       continue;
383     reductionSymbol = (*container.reductions())[i].cast<SymbolRefAttr>();
384     break;
385   }
386   assert(reductionSymbol &&
387          "reduction operation must be associated with a declaration");
388 
389   return SymbolTable::lookupNearestSymbolFrom<omp::ReductionDeclareOp>(
390       container, reductionSymbol);
391 }
392 
393 /// Populates `reductions` with reduction declarations used in the given loop.
394 static void
395 collectReductionDecls(omp::WsLoopOp loop,
396                       SmallVectorImpl<omp::ReductionDeclareOp> &reductions) {
397   Optional<ArrayAttr> attr = loop.reductions();
398   if (!attr)
399     return;
400 
401   reductions.reserve(reductions.size() + loop.getNumReductionVars());
402   for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
403     reductions.push_back(
404         SymbolTable::lookupNearestSymbolFrom<omp::ReductionDeclareOp>(
405             loop, symbolRef));
406   }
407 }
408 
409 /// Translates the blocks contained in the given region and appends them to at
410 /// the current insertion point of `builder`. The operations of the entry block
411 /// are appended to the current insertion block, which is not expected to have a
412 /// terminator. If set, `continuationBlockArgs` is populated with translated
413 /// values that correspond to the values omp.yield'ed from the region.
414 static LogicalResult inlineConvertOmpRegions(
415     Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
416     LLVM::ModuleTranslation &moduleTranslation,
417     SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
418   if (region.empty())
419     return success();
420 
421   // Special case for single-block regions that don't create additional blocks:
422   // insert operations without creating additional blocks.
423   if (llvm::hasSingleElement(region)) {
424     moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
425     if (failed(moduleTranslation.convertBlock(
426             region.front(), /*ignoreArguments=*/true, builder)))
427       return failure();
428 
429     // The continuation arguments are simply the translated terminator operands.
430     if (continuationBlockArgs)
431       llvm::append_range(
432           *continuationBlockArgs,
433           moduleTranslation.lookupValues(region.front().back().getOperands()));
434 
435     // Drop the mapping that is no longer necessary so that the same region can
436     // be processed multiple times.
437     moduleTranslation.forgetMapping(region);
438     return success();
439   }
440 
441   // Create the continuation block manually instead of calling splitBlock
442   // because the current insertion block may not have a terminator.
443   llvm::BasicBlock *continuationBlock =
444       llvm::BasicBlock::Create(builder.getContext(), blockName + ".cont",
445                                builder.GetInsertBlock()->getParent(),
446                                builder.GetInsertBlock()->getNextNode());
447   builder.CreateBr(continuationBlock);
448 
449   LogicalResult bodyGenStatus = success();
450   SmallVector<llvm::PHINode *> phis;
451   convertOmpOpRegions(region, blockName, *builder.GetInsertBlock(),
452                       *continuationBlock, builder, moduleTranslation,
453                       bodyGenStatus, &phis);
454   if (failed(bodyGenStatus))
455     return failure();
456   if (continuationBlockArgs)
457     llvm::append_range(*continuationBlockArgs, phis);
458   builder.SetInsertPoint(continuationBlock,
459                          continuationBlock->getFirstInsertionPt());
460   return success();
461 }
462 
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.

/// Generator for the non-atomic form of a reduction. As used in this file, it
/// receives the insertion point to emit code at, the two values to combine and
/// an output reference that is set to the combined value. It returns the
/// insertion point after the emitted code, or a default-constructed insertion
/// point if translating the reduction body failed.
using OwningReductionGen = std::function<llvm::OpenMPIRBuilder::InsertPointTy(
    llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
    llvm::Value *&)>;
/// Generator for the atomic form of a reduction. As used in this file, it
/// receives the insertion point to emit code at, a type (ignored by the
/// generators created here) and the two values to combine. It returns the
/// insertion point after the emitted code, or a default-constructed insertion
/// point if translating the atomic reduction body failed.
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
} // namespace
474 
475 /// Create an OpenMPIRBuilder-compatible reduction generator for the given
476 /// reduction declaration. The generator uses `builder` but ignores its
477 /// insertion point.
478 static OwningReductionGen
479 makeReductionGen(omp::ReductionDeclareOp decl, llvm::IRBuilderBase &builder,
480                  LLVM::ModuleTranslation &moduleTranslation) {
481   // The lambda is mutable because we need access to non-const methods of decl
482   // (which aren't actually mutating it), and we must capture decl by-value to
483   // avoid the dangling reference after the parent function returns.
484   OwningReductionGen gen =
485       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
486                 llvm::Value *lhs, llvm::Value *rhs,
487                 llvm::Value *&result) mutable {
488         Region &reductionRegion = decl.reductionRegion();
489         moduleTranslation.mapValue(reductionRegion.front().getArgument(0), lhs);
490         moduleTranslation.mapValue(reductionRegion.front().getArgument(1), rhs);
491         builder.restoreIP(insertPoint);
492         SmallVector<llvm::Value *> phis;
493         if (failed(inlineConvertOmpRegions(reductionRegion,
494                                            "omp.reduction.nonatomic.body",
495                                            builder, moduleTranslation, &phis)))
496           return llvm::OpenMPIRBuilder::InsertPointTy();
497         assert(phis.size() == 1);
498         result = phis[0];
499         return builder.saveIP();
500       };
501   return gen;
502 }
503 
504 /// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
505 /// given reduction declaration. The generator uses `builder` but ignores its
506 /// insertion point. Returns null if there is no atomic region available in the
507 /// reduction declaration.
508 static OwningAtomicReductionGen
509 makeAtomicReductionGen(omp::ReductionDeclareOp decl,
510                        llvm::IRBuilderBase &builder,
511                        LLVM::ModuleTranslation &moduleTranslation) {
512   if (decl.atomicReductionRegion().empty())
513     return OwningAtomicReductionGen();
514 
515   // The lambda is mutable because we need access to non-const methods of decl
516   // (which aren't actually mutating it), and we must capture decl by-value to
517   // avoid the dangling reference after the parent function returns.
518   OwningAtomicReductionGen atomicGen =
519       [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
520                 llvm::Value *lhs, llvm::Value *rhs) mutable {
521         Region &atomicRegion = decl.atomicReductionRegion();
522         moduleTranslation.mapValue(atomicRegion.front().getArgument(0), lhs);
523         moduleTranslation.mapValue(atomicRegion.front().getArgument(1), rhs);
524         builder.restoreIP(insertPoint);
525         SmallVector<llvm::Value *> phis;
526         if (failed(inlineConvertOmpRegions(atomicRegion,
527                                            "omp.reduction.atomic.body", builder,
528                                            moduleTranslation, &phis)))
529           return llvm::OpenMPIRBuilder::InsertPointTy();
530         assert(phis.empty());
531         return builder.saveIP();
532       };
533   return atomicGen;
534 }
535 
536 /// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
537 static LogicalResult
538 convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
539                   LLVM::ModuleTranslation &moduleTranslation) {
540   auto orderedOp = cast<omp::OrderedOp>(opInst);
541 
542   omp::ClauseDepend dependType = *orderedOp.depend_type_val();
543   bool isDependSource = dependType == omp::ClauseDepend::dependsource;
544   unsigned numLoops = orderedOp.num_loops_val().getValue();
545   SmallVector<llvm::Value *> vecValues =
546       moduleTranslation.lookupValues(orderedOp.depend_vec_vars());
547 
548   size_t indexVecValues = 0;
549   while (indexVecValues < vecValues.size()) {
550     SmallVector<llvm::Value *> storeValues;
551     storeValues.reserve(numLoops);
552     for (unsigned i = 0; i < numLoops; i++) {
553       storeValues.push_back(vecValues[indexVecValues]);
554       indexVecValues++;
555     }
556     llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
557         findAllocaInsertPoint(builder, moduleTranslation);
558     llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
559     builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
560         ompLoc, allocaIP, numLoops, storeValues, ".cnt.addr", isDependSource));
561   }
562   return success();
563 }
564 
565 /// Converts an OpenMP 'ordered_region' operation into LLVM IR using
566 /// OpenMPIRBuilder.
567 static LogicalResult
568 convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
569                         LLVM::ModuleTranslation &moduleTranslation) {
570   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
571   auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
572 
573   // TODO: The code generation for ordered simd directive is not supported yet.
574   if (orderedRegionOp.simd())
575     return failure();
576 
577   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
578   // relying on captured variables.
579   LogicalResult bodyGenStatus = success();
580 
581   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
582                        llvm::BasicBlock &continuationBlock) {
583     // OrderedOp has only one region associated with it.
584     auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
585     convertOmpOpRegions(region, "omp.ordered.region", *codeGenIP.getBlock(),
586                         continuationBlock, builder, moduleTranslation,
587                         bodyGenStatus);
588   };
589 
590   // TODO: Perform finalization actions for variables. This has to be
591   // called for variables which have destructors/finalizers.
592   auto finiCB = [&](InsertPointTy codeGenIP) {};
593 
594   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
595   builder.restoreIP(
596       moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
597           ompLoc, bodyGenCB, finiCB, !orderedRegionOp.simd()));
598   return bodyGenStatus;
599 }
600 
601 static LogicalResult
602 convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
603                    LLVM::ModuleTranslation &moduleTranslation) {
604   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
605   using StorableBodyGenCallbackTy =
606       llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
607 
608   auto sectionsOp = cast<omp::SectionsOp>(opInst);
609 
610   // TODO: Support the following clauses: private, firstprivate, lastprivate,
611   // reduction, allocate
612   if (!sectionsOp.reduction_vars().empty() || sectionsOp.reductions() ||
613       !sectionsOp.allocate_vars().empty() ||
614       !sectionsOp.allocators_vars().empty())
615     return emitError(sectionsOp.getLoc())
616            << "reduction and allocate clauses are not supported for sections "
617               "construct";
618 
619   LogicalResult bodyGenStatus = success();
620   SmallVector<StorableBodyGenCallbackTy> sectionCBs;
621 
622   for (Operation &op : *sectionsOp.region().begin()) {
623     auto sectionOp = dyn_cast<omp::SectionOp>(op);
624     if (!sectionOp) // omp.terminator
625       continue;
626 
627     Region &region = sectionOp.region();
628     auto sectionCB = [&region, &builder, &moduleTranslation, &bodyGenStatus](
629                          InsertPointTy allocaIP, InsertPointTy codeGenIP,
630                          llvm::BasicBlock &finiBB) {
631       builder.restoreIP(codeGenIP);
632       builder.CreateBr(&finiBB);
633       convertOmpOpRegions(region, "omp.section.region", *codeGenIP.getBlock(),
634                           finiBB, builder, moduleTranslation, bodyGenStatus);
635     };
636     sectionCBs.push_back(sectionCB);
637   }
638 
639   // No sections within omp.sections operation - skip generation. This situation
640   // is only possible if there is only a terminator operation inside the
641   // sections operation
642   if (sectionCBs.empty())
643     return success();
644 
645   assert(isa<omp::SectionOp>(*sectionsOp.region().op_begin()));
646 
647   // TODO: Perform appropriate actions according to the data-sharing
648   // attribute (shared, private, firstprivate, ...) of variables.
649   // Currently defaults to shared.
650   auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
651                     llvm::Value &vPtr,
652                     llvm::Value *&replacementValue) -> InsertPointTy {
653     replacementValue = &vPtr;
654     return codeGenIP;
655   };
656 
657   // TODO: Perform finalization actions for variables. This has to be
658   // called for variables which have destructors/finalizers.
659   auto finiCB = [&](InsertPointTy codeGenIP) {};
660 
661   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
662       findAllocaInsertPoint(builder, moduleTranslation);
663   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
664   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSections(
665       ompLoc, allocaIP, sectionCBs, privCB, finiCB, false,
666       sectionsOp.nowait()));
667   return bodyGenStatus;
668 }
669 
670 /// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
671 static LogicalResult
672 convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
673                  LLVM::ModuleTranslation &moduleTranslation) {
674   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
675   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
676   LogicalResult bodyGenStatus = success();
677   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP,
678                     llvm::BasicBlock &continuationBB) {
679     convertOmpOpRegions(singleOp.region(), "omp.single.region",
680                         *codegenIP.getBlock(), continuationBB, builder,
681                         moduleTranslation, bodyGenStatus);
682   };
683   auto finiCB = [&](InsertPointTy codeGenIP) {};
684   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createSingle(
685       ompLoc, bodyCB, finiCB, singleOp.nowait(), /*DidIt=*/nullptr));
686   return bodyGenStatus;
687 }
688 
689 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
690 static LogicalResult
691 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
692                  LLVM::ModuleTranslation &moduleTranslation) {
693   auto loop = cast<omp::WsLoopOp>(opInst);
694   // TODO: this should be in the op verifier instead.
695   if (loop.lowerBound().empty())
696     return failure();
697 
698   // Static is the default.
699   auto schedule =
700       loop.schedule_val().getValueOr(omp::ClauseScheduleKind::Static);
701 
702   // Find the loop configuration.
703   llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
704   llvm::Type *ivType = step->getType();
705   llvm::Value *chunk = nullptr;
706   if (loop.schedule_chunk_var()) {
707     llvm::Value *chunkVar =
708         moduleTranslation.lookupValue(loop.schedule_chunk_var());
709     llvm::Type *chunkVarType = chunkVar->getType();
710     assert(chunkVarType->isIntegerTy() &&
711            "chunk size must be one integer expression");
712     if (chunkVarType->getIntegerBitWidth() < ivType->getIntegerBitWidth())
713       chunk = builder.CreateSExt(chunkVar, ivType);
714     else if (chunkVarType->getIntegerBitWidth() > ivType->getIntegerBitWidth())
715       chunk = builder.CreateTrunc(chunkVar, ivType);
716     else
717       chunk = chunkVar;
718   }
719 
720   SmallVector<omp::ReductionDeclareOp> reductionDecls;
721   collectReductionDecls(loop, reductionDecls);
722   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
723       findAllocaInsertPoint(builder, moduleTranslation);
724 
725   // Allocate space for privatized reduction variables.
726   SmallVector<llvm::Value *> privateReductionVariables;
727   DenseMap<Value, llvm::Value *> reductionVariableMap;
728   unsigned numReductions = loop.getNumReductionVars();
729   privateReductionVariables.reserve(numReductions);
730   if (numReductions != 0) {
731     llvm::IRBuilderBase::InsertPointGuard guard(builder);
732     builder.restoreIP(allocaIP);
733     for (unsigned i = 0; i < numReductions; ++i) {
734       auto reductionType =
735           loop.reduction_vars()[i].getType().cast<LLVM::LLVMPointerType>();
736       llvm::Value *var = builder.CreateAlloca(
737           moduleTranslation.convertType(reductionType.getElementType()));
738       privateReductionVariables.push_back(var);
739       reductionVariableMap.try_emplace(loop.reduction_vars()[i], var);
740     }
741   }
742 
743   // Store the mapping between reduction variables and their private copies on
744   // ModuleTranslation stack. It can be then recovered when translating
745   // omp.reduce operations in a separate call.
746   LLVM::ModuleTranslation::SaveStack<OpenMPVarMappingStackFrame> mappingGuard(
747       moduleTranslation, reductionVariableMap);
748 
749   // Before the loop, store the initial values of reductions into reduction
750   // variables. Although this could be done after allocas, we don't want to mess
751   // up with the alloca insertion point.
752   for (unsigned i = 0; i < numReductions; ++i) {
753     SmallVector<llvm::Value *> phis;
754     if (failed(inlineConvertOmpRegions(reductionDecls[i].initializerRegion(),
755                                        "omp.reduction.neutral", builder,
756                                        moduleTranslation, &phis)))
757       return failure();
758     assert(phis.size() == 1 && "expected one value to be yielded from the "
759                                "reduction neutral element declaration region");
760     builder.CreateStore(phis[0], privateReductionVariables[i]);
761   }
762 
763   // Set up the source location value for OpenMP runtime.
764   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
765 
766   // Generator of the canonical loop body.
767   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
768   // relying on captured variables.
769   SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
770   SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
771   LogicalResult bodyGenStatus = success();
772   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
773     // Make sure further conversions know about the induction variable.
774     moduleTranslation.mapValue(
775         loop.getRegion().front().getArgument(loopInfos.size()), iv);
776 
777     // Capture the body insertion point for use in nested loops. BodyIP of the
778     // CanonicalLoopInfo always points to the beginning of the entry block of
779     // the body.
780     bodyInsertPoints.push_back(ip);
781 
782     if (loopInfos.size() != loop.getNumLoops() - 1)
783       return;
784 
785     // Convert the body of the loop.
786     llvm::BasicBlock *entryBlock = ip.getBlock();
787     llvm::BasicBlock *exitBlock =
788         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
789     convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
790                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
791   };
792 
793   // Delegate actual loop construction to the OpenMP IRBuilder.
794   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
795   // i.e. it has a positive step, uses signed integer semantics. Reconsider
796   // this code when WsLoop clearly supports more cases.
797   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
798   for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
799     llvm::Value *lowerBound =
800         moduleTranslation.lookupValue(loop.lowerBound()[i]);
801     llvm::Value *upperBound =
802         moduleTranslation.lookupValue(loop.upperBound()[i]);
803     llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
804 
805     // Make sure loop trip count are emitted in the preheader of the outermost
806     // loop at the latest so that they are all available for the new collapsed
807     // loop will be created below.
808     llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
809     llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
810     if (i != 0) {
811       loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
812       computeIP = loopInfos.front()->getPreheaderIP();
813     }
814     loopInfos.push_back(ompBuilder->createCanonicalLoop(
815         loc, bodyGen, lowerBound, upperBound, step,
816         /*IsSigned=*/true, loop.inclusive(), computeIP));
817 
818     if (failed(bodyGenStatus))
819       return failure();
820   }
821 
822   // Collapse loops. Store the insertion point because LoopInfos may get
823   // invalidated.
824   llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
825   llvm::CanonicalLoopInfo *loopInfo =
826       ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
827 
828   allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
829 
830   // TODO: Handle doacross loops when the ordered clause has a parameter.
831   bool isOrdered = loop.ordered_val().hasValue();
832   Optional<omp::ScheduleModifier> scheduleModifier = loop.schedule_modifier();
833   bool isSimd = loop.simd_modifier();
834 
835   ompBuilder->applyWorkshareLoop(
836       ompLoc.DL, loopInfo, allocaIP, !loop.nowait(),
837       convertToScheduleKind(schedule), chunk, isSimd,
838       scheduleModifier == omp::ScheduleModifier::monotonic,
839       scheduleModifier == omp::ScheduleModifier::nonmonotonic, isOrdered);
840 
841   // Continue building IR after the loop. Note that the LoopInfo returned by
842   // `collapseLoops` points inside the outermost loop and is intended for
843   // potential further loop transformations. Use the insertion point stored
844   // before collapsing loops instead.
845   builder.restoreIP(afterIP);
846 
847   // Process the reductions if required.
848   if (numReductions == 0)
849     return success();
850 
851   // Create the reduction generators. We need to own them here because
852   // ReductionInfo only accepts references to the generators.
853   SmallVector<OwningReductionGen> owningReductionGens;
854   SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
855   for (unsigned i = 0; i < numReductions; ++i) {
856     owningReductionGens.push_back(
857         makeReductionGen(reductionDecls[i], builder, moduleTranslation));
858     owningAtomicReductionGens.push_back(
859         makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
860   }
861 
862   // Collect the reduction information.
863   SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
864   reductionInfos.reserve(numReductions);
865   for (unsigned i = 0; i < numReductions; ++i) {
866     llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
867     if (owningAtomicReductionGens[i])
868       atomicGen = owningAtomicReductionGens[i];
869     auto reductionType =
870         loop.reduction_vars()[i].getType().cast<LLVM::LLVMPointerType>();
871     llvm::Value *variable =
872         moduleTranslation.lookupValue(loop.reduction_vars()[i]);
873     reductionInfos.push_back(
874         {moduleTranslation.convertType(reductionType.getElementType()),
875          variable, privateReductionVariables[i], owningReductionGens[i],
876          atomicGen});
877   }
878 
879   // The call to createReductions below expects the block to have a
880   // terminator. Create an unreachable instruction to serve as terminator
881   // and remove it later.
882   llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
883   builder.SetInsertPoint(tempTerminator);
884   llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
885       ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
886                                    loop.nowait());
887   if (!contInsertPoint.getBlock())
888     return loop->emitOpError() << "failed to convert reductions";
889   auto nextInsertionPoint =
890       ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
891   tempTerminator->eraseFromParent();
892   builder.restoreIP(nextInsertionPoint);
893 
894   return success();
895 }
896 
897 /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
898 static LogicalResult
899 convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,
900                    LLVM::ModuleTranslation &moduleTranslation) {
901   auto loop = cast<omp::SimdLoopOp>(opInst);
902 
903   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
904 
905   // Generator of the canonical loop body.
906   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
907   // relying on captured variables.
908   SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
909   SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
910   LogicalResult bodyGenStatus = success();
911   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
912     // Make sure further conversions know about the induction variable.
913     moduleTranslation.mapValue(
914         loop.getRegion().front().getArgument(loopInfos.size()), iv);
915 
916     // Capture the body insertion point for use in nested loops. BodyIP of the
917     // CanonicalLoopInfo always points to the beginning of the entry block of
918     // the body.
919     bodyInsertPoints.push_back(ip);
920 
921     if (loopInfos.size() != loop.getNumLoops() - 1)
922       return;
923 
924     // Convert the body of the loop.
925     llvm::BasicBlock *entryBlock = ip.getBlock();
926     llvm::BasicBlock *exitBlock =
927         entryBlock->splitBasicBlock(ip.getPoint(), "omp.simdloop.exit");
928     convertOmpOpRegions(loop.region(), "omp.simdloop.region", *entryBlock,
929                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
930   };
931 
932   // Delegate actual loop construction to the OpenMP IRBuilder.
933   // TODO: this currently assumes SimdLoop is semantically similar to SCF loop,
934   // i.e. it has a positive step, uses signed integer semantics. Reconsider
935   // this code when SimdLoop clearly supports more cases.
936   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
937   for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
938     llvm::Value *lowerBound =
939         moduleTranslation.lookupValue(loop.lowerBound()[i]);
940     llvm::Value *upperBound =
941         moduleTranslation.lookupValue(loop.upperBound()[i]);
942     llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
943 
944     // Make sure loop trip count are emitted in the preheader of the outermost
945     // loop at the latest so that they are all available for the new collapsed
946     // loop will be created below.
947     llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
948     llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
949     if (i != 0) {
950       loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
951                                                        ompLoc.DL);
952       computeIP = loopInfos.front()->getPreheaderIP();
953     }
954     loopInfos.push_back(ompBuilder->createCanonicalLoop(
955         loc, bodyGen, lowerBound, upperBound, step,
956         /*IsSigned=*/true, /*Inclusive=*/true, computeIP));
957 
958     if (failed(bodyGenStatus))
959       return failure();
960   }
961 
962   // Collapse loops.
963   llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
964   llvm::CanonicalLoopInfo *loopInfo =
965       ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
966 
967   ompBuilder->applySimd(ompLoc.DL, loopInfo);
968 
969   builder.restoreIP(afterIP);
970   return success();
971 }
972 
973 /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
974 llvm::AtomicOrdering
975 convertAtomicOrdering(Optional<omp::ClauseMemoryOrderKind> ao) {
976   if (!ao)
977     return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
978 
979   switch (*ao) {
980   case omp::ClauseMemoryOrderKind::Seq_cst:
981     return llvm::AtomicOrdering::SequentiallyConsistent;
982   case omp::ClauseMemoryOrderKind::Acq_rel:
983     return llvm::AtomicOrdering::AcquireRelease;
984   case omp::ClauseMemoryOrderKind::Acquire:
985     return llvm::AtomicOrdering::Acquire;
986   case omp::ClauseMemoryOrderKind::Release:
987     return llvm::AtomicOrdering::Release;
988   case omp::ClauseMemoryOrderKind::Relaxed:
989     return llvm::AtomicOrdering::Monotonic;
990   }
991   llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
992 }
993 
994 /// Convert omp.atomic.read operation to LLVM IR.
995 static LogicalResult
996 convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
997                      LLVM::ModuleTranslation &moduleTranslation) {
998 
999   auto readOp = cast<omp::AtomicReadOp>(opInst);
1000   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1001 
1002   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1003 
1004   llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.memory_order_val());
1005   llvm::Value *x = moduleTranslation.lookupValue(readOp.x());
1006   Type xTy = readOp.x().getType().cast<omp::PointerLikeType>().getElementType();
1007   llvm::Value *v = moduleTranslation.lookupValue(readOp.v());
1008   Type vTy = readOp.v().getType().cast<omp::PointerLikeType>().getElementType();
1009   llvm::OpenMPIRBuilder::AtomicOpValue V = {
1010       v, moduleTranslation.convertType(vTy), false, false};
1011   llvm::OpenMPIRBuilder::AtomicOpValue X = {
1012       x, moduleTranslation.convertType(xTy), false, false};
1013   builder.restoreIP(ompBuilder->createAtomicRead(ompLoc, X, V, AO));
1014   return success();
1015 }
1016 
1017 /// Converts an omp.atomic.write operation to LLVM IR.
1018 static LogicalResult
1019 convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
1020                       LLVM::ModuleTranslation &moduleTranslation) {
1021   auto writeOp = cast<omp::AtomicWriteOp>(opInst);
1022   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1023 
1024   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1025   llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.memory_order_val());
1026   llvm::Value *expr = moduleTranslation.lookupValue(writeOp.value());
1027   llvm::Value *dest = moduleTranslation.lookupValue(writeOp.address());
1028   llvm::Type *ty = moduleTranslation.convertType(writeOp.value().getType());
1029   llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
1030                                             /*isVolatile=*/false};
1031   builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
1032   return success();
1033 }
1034 
1035 /// Converts an LLVM dialect binary operation to the corresponding enum value
1036 /// for `atomicrmw` supported binary operation.
1037 llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
1038   return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op)
1039       .Case([&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
1040       .Case([&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
1041       .Case([&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
1042       .Case([&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
1043       .Case([&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
1044       .Case([&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
1045       .Case([&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
1046       .Case([&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
1047       .Case([&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
1048       .Default(llvm::AtomicRMWInst::BinOp::BAD_BINOP);
1049 }
1050 
1051 /// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
1052 static LogicalResult
1053 convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
1054                        llvm::IRBuilderBase &builder,
1055                        LLVM::ModuleTranslation &moduleTranslation) {
1056   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1057 
1058   // Convert values and types.
1059   auto &innerOpList = opInst.region().front().getOperations();
1060   if (innerOpList.size() != 2)
1061     return opInst.emitError("exactly two operations are allowed inside an "
1062                             "atomic update region while lowering to LLVM IR");
1063 
1064   Operation &innerUpdateOp = innerOpList.front();
1065 
1066   if (innerUpdateOp.getNumOperands() != 2 ||
1067       !llvm::is_contained(innerUpdateOp.getOperands(),
1068                           opInst.getRegion().getArgument(0)))
1069     return opInst.emitError(
1070         "the update operation inside the region must be a binary operation and "
1071         "that update operation must have the region argument as an operand");
1072 
1073   llvm::AtomicRMWInst::BinOp binop = convertBinOpToAtomic(innerUpdateOp);
1074 
1075   bool isXBinopExpr =
1076       innerUpdateOp.getNumOperands() > 0 &&
1077       innerUpdateOp.getOperand(0) == opInst.getRegion().getArgument(0);
1078 
1079   mlir::Value mlirExpr = (isXBinopExpr ? innerUpdateOp.getOperand(1)
1080                                        : innerUpdateOp.getOperand(0));
1081   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
1082   llvm::Value *llvmX = moduleTranslation.lookupValue(opInst.x());
1083   LLVM::LLVMPointerType mlirXType =
1084       opInst.x().getType().cast<LLVM::LLVMPointerType>();
1085   llvm::Type *llvmXElementType =
1086       moduleTranslation.convertType(mlirXType.getElementType());
1087   llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
1088                                                       /*isSigned=*/false,
1089                                                       /*isVolatile=*/false};
1090 
1091   llvm::AtomicOrdering atomicOrdering =
1092       convertAtomicOrdering(opInst.memory_order_val());
1093 
1094   // Generate update code.
1095   LogicalResult updateGenStatus = success();
1096   auto updateFn = [&opInst, &moduleTranslation, &updateGenStatus](
1097                       llvm::Value *atomicx,
1098                       llvm::IRBuilder<> &builder) -> llvm::Value * {
1099     Block &bb = *opInst.region().begin();
1100     moduleTranslation.mapValue(*opInst.region().args_begin(), atomicx);
1101     moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
1102     if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
1103       updateGenStatus = (opInst.emitError()
1104                          << "unable to convert update operation to llvm IR");
1105       return nullptr;
1106     }
1107     omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
1108     assert(yieldop && yieldop.results().size() == 1 &&
1109            "terminator must be omp.yield op and it must have exactly one "
1110            "argument");
1111     return moduleTranslation.lookupValue(yieldop.results()[0]);
1112   };
1113 
1114   // Handle ambiguous alloca, if any.
1115   auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1116   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1117   builder.restoreIP(ompBuilder->createAtomicUpdate(
1118       ompLoc, allocaIP, llvmAtomicX, llvmExpr, atomicOrdering, binop, updateFn,
1119       isXBinopExpr));
1120   return updateGenStatus;
1121 }
1122 
1123 static LogicalResult
1124 convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
1125                         llvm::IRBuilderBase &builder,
1126                         LLVM::ModuleTranslation &moduleTranslation) {
1127   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1128   mlir::Value mlirExpr;
1129   bool isXBinopExpr = false, isPostfixUpdate = false;
1130   llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
1131 
1132   omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
1133   omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
1134 
1135   assert((atomicUpdateOp || atomicWriteOp) &&
1136          "internal op must be an atomic.update or atomic.write op");
1137 
1138   if (atomicWriteOp) {
1139     isPostfixUpdate = true;
1140     mlirExpr = atomicWriteOp.value();
1141   } else {
1142     isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
1143                       atomicCaptureOp.getAtomicUpdateOp().getOperation();
1144     auto &innerOpList = atomicUpdateOp.region().front().getOperations();
1145     if (innerOpList.size() != 2)
1146       return atomicUpdateOp.emitError(
1147           "exactly two operations are allowed inside an "
1148           "atomic update region while lowering to LLVM IR");
1149     Operation *innerUpdateOp = atomicUpdateOp.getFirstOp();
1150     if (innerUpdateOp->getNumOperands() != 2 ||
1151         !llvm::is_contained(innerUpdateOp->getOperands(),
1152                             atomicUpdateOp.getRegion().getArgument(0)))
1153       return atomicUpdateOp.emitError(
1154           "the update operation inside the region must be a binary operation "
1155           "and that update operation must have the region argument as an "
1156           "operand");
1157     binop = convertBinOpToAtomic(*innerUpdateOp);
1158 
1159     isXBinopExpr = innerUpdateOp->getOperand(0) ==
1160                    atomicUpdateOp.getRegion().getArgument(0);
1161 
1162     mlirExpr = (isXBinopExpr ? innerUpdateOp->getOperand(1)
1163                              : innerUpdateOp->getOperand(0));
1164   }
1165 
1166   llvm::Value *llvmExpr = moduleTranslation.lookupValue(mlirExpr);
1167   llvm::Value *llvmX =
1168       moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().x());
1169   llvm::Value *llvmV =
1170       moduleTranslation.lookupValue(atomicCaptureOp.getAtomicReadOp().v());
1171   auto mlirXType = atomicCaptureOp.getAtomicReadOp()
1172                        .x()
1173                        .getType()
1174                        .cast<LLVM::LLVMPointerType>();
1175   llvm::Type *llvmXElementType =
1176       moduleTranslation.convertType(mlirXType.getElementType());
1177   llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {llvmX, llvmXElementType,
1178                                                       /*isSigned=*/false,
1179                                                       /*isVolatile=*/false};
1180   llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {llvmV, llvmXElementType,
1181                                                       /*isSigned=*/false,
1182                                                       /*isVolatile=*/false};
1183 
1184   llvm::AtomicOrdering atomicOrdering =
1185       convertAtomicOrdering(atomicCaptureOp.memory_order_val());
1186 
1187   LogicalResult updateGenStatus = success();
1188   auto updateFn = [&](llvm::Value *atomicx,
1189                       llvm::IRBuilder<> &builder) -> llvm::Value * {
1190     if (atomicWriteOp)
1191       return moduleTranslation.lookupValue(atomicWriteOp.value());
1192     Block &bb = *atomicUpdateOp.region().begin();
1193     moduleTranslation.mapValue(*atomicUpdateOp.region().args_begin(), atomicx);
1194     moduleTranslation.mapBlock(&bb, builder.GetInsertBlock());
1195     if (failed(moduleTranslation.convertBlock(bb, true, builder))) {
1196       updateGenStatus = (atomicUpdateOp.emitError()
1197                          << "unable to convert update operation to llvm IR");
1198       return nullptr;
1199     }
1200     omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
1201     assert(yieldop && yieldop.results().size() == 1 &&
1202            "terminator must be omp.yield op and it must have exactly one "
1203            "argument");
1204     return moduleTranslation.lookupValue(yieldop.results()[0]);
1205   };
1206 
1207   // Handle ambiguous alloca, if any.
1208   auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1209   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1210   builder.restoreIP(ompBuilder->createAtomicCapture(
1211       ompLoc, allocaIP, llvmAtomicX, llvmAtomicV, llvmExpr, atomicOrdering,
1212       binop, updateFn, atomicUpdateOp, isPostfixUpdate, isXBinopExpr));
1213   return updateGenStatus;
1214 }
1215 
1216 /// Converts an OpenMP reduction operation using OpenMPIRBuilder. Expects the
1217 /// mapping between reduction variables and their private equivalents to have
1218 /// been stored on the ModuleTranslation stack. Currently only supports
1219 /// reduction within WsLoopOp, but can be easily extended.
1220 static LogicalResult
1221 convertOmpReductionOp(omp::ReductionOp reductionOp,
1222                       llvm::IRBuilderBase &builder,
1223                       LLVM::ModuleTranslation &moduleTranslation) {
1224   // Find the declaration that corresponds to the reduction op.
1225   auto reductionContainer = reductionOp->getParentOfType<omp::WsLoopOp>();
1226   omp::ReductionDeclareOp declaration =
1227       findReductionDecl(reductionContainer, reductionOp);
1228   assert(declaration && "could not find reduction declaration");
1229 
1230   // Retrieve the mapping between reduction variables and their private
1231   // equivalents.
1232   const DenseMap<Value, llvm::Value *> *reductionVariableMap = nullptr;
1233   moduleTranslation.stackWalk<OpenMPVarMappingStackFrame>(
1234       [&](const OpenMPVarMappingStackFrame &frame) {
1235         reductionVariableMap = &frame.mapping;
1236         return WalkResult::interrupt();
1237       });
1238   assert(reductionVariableMap && "couldn't find private reduction variables");
1239 
1240   // Translate the reduction operation by emitting the body of the corresponding
1241   // reduction declaration.
1242   Region &reductionRegion = declaration.reductionRegion();
1243   llvm::Value *privateReductionVar =
1244       reductionVariableMap->lookup(reductionOp.accumulator());
1245   llvm::Value *reductionVal = builder.CreateLoad(
1246       moduleTranslation.convertType(reductionOp.operand().getType()),
1247       privateReductionVar);
1248 
1249   moduleTranslation.mapValue(reductionRegion.front().getArgument(0),
1250                              reductionVal);
1251   moduleTranslation.mapValue(
1252       reductionRegion.front().getArgument(1),
1253       moduleTranslation.lookupValue(reductionOp.operand()));
1254 
1255   SmallVector<llvm::Value *> phis;
1256   if (failed(inlineConvertOmpRegions(reductionRegion, "omp.reduction.body",
1257                                      builder, moduleTranslation, &phis)))
1258     return failure();
1259   assert(phis.size() == 1 && "expected one value to be yielded from "
1260                              "the reduction body declaration region");
1261   builder.CreateStore(phis[0], privateReductionVar);
1262   return success();
1263 }
1264 
1265 /// Converts an OpenMP Threadprivate operation into LLVM IR using
1266 /// OpenMPIRBuilder.
1267 static LogicalResult
1268 convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
1269                         LLVM::ModuleTranslation &moduleTranslation) {
1270   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1271   auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
1272 
1273   Value symAddr = threadprivateOp.sym_addr();
1274   auto *symOp = symAddr.getDefiningOp();
1275   if (!isa<LLVM::AddressOfOp>(symOp))
1276     return opInst.emitError("Addressing symbol not found");
1277   LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
1278 
1279   LLVM::GlobalOp global = addressOfOp.getGlobal();
1280   llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
1281   llvm::Value *data =
1282       builder.CreateBitCast(globalValue, builder.getInt8PtrTy());
1283   llvm::Type *type = globalValue->getValueType();
1284   llvm::TypeSize typeSize =
1285       builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
1286           type);
1287   llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedSize());
1288   llvm::StringRef suffix = llvm::StringRef(".cache", 6);
1289   std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str();
1290   // Emit runtime function and bitcast its type (i8*) to real data type.
1291   llvm::Value *callInst =
1292       moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate(
1293           ompLoc, data, size, cacheName);
1294   llvm::Value *result = builder.CreateBitCast(callInst, globalValue->getType());
1295   moduleTranslation.mapValue(opInst.getResult(0), result);
1296   return success();
1297 }
1298 
namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenMP dialect to LLVM IR.
///
/// The class adds no state of its own: it inherits the constructors of
/// `LLVMTranslationDialectInterface` and only overrides `convertOperation`.
class OpenMPDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  // Reuse the base-class constructors unchanged.
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`.
  ///
  /// Returns success if `op` was translated (or can be skipped) and failure,
  /// after emitting an error on `op`, if the operation is not supported.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // namespace
1316 
1317 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
1318 /// (including OpenMP runtime calls).
1319 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
1320     Operation *op, llvm::IRBuilderBase &builder,
1321     LLVM::ModuleTranslation &moduleTranslation) const {
1322 
1323   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1324 
1325   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
1326       .Case([&](omp::BarrierOp) {
1327         ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1328         return success();
1329       })
1330       .Case([&](omp::TaskwaitOp) {
1331         ompBuilder->createTaskwait(builder.saveIP());
1332         return success();
1333       })
1334       .Case([&](omp::TaskyieldOp) {
1335         ompBuilder->createTaskyield(builder.saveIP());
1336         return success();
1337       })
1338       .Case([&](omp::FlushOp) {
1339         // No support in Openmp runtime function (__kmpc_flush) to accept
1340         // the argument list.
1341         // OpenMP standard states the following:
1342         //  "An implementation may implement a flush with a list by ignoring
1343         //   the list, and treating it the same as a flush without a list."
1344         //
1345         // The argument list is discarded so that, flush with a list is treated
1346         // same as a flush without a list.
1347         ompBuilder->createFlush(builder.saveIP());
1348         return success();
1349       })
1350       .Case([&](omp::ParallelOp op) {
1351         return convertOmpParallel(op, builder, moduleTranslation);
1352       })
1353       .Case([&](omp::ReductionOp reductionOp) {
1354         return convertOmpReductionOp(reductionOp, builder, moduleTranslation);
1355       })
1356       .Case([&](omp::MasterOp) {
1357         return convertOmpMaster(*op, builder, moduleTranslation);
1358       })
1359       .Case([&](omp::CriticalOp) {
1360         return convertOmpCritical(*op, builder, moduleTranslation);
1361       })
1362       .Case([&](omp::OrderedRegionOp) {
1363         return convertOmpOrderedRegion(*op, builder, moduleTranslation);
1364       })
1365       .Case([&](omp::OrderedOp) {
1366         return convertOmpOrdered(*op, builder, moduleTranslation);
1367       })
1368       .Case([&](omp::WsLoopOp) {
1369         return convertOmpWsLoop(*op, builder, moduleTranslation);
1370       })
1371       .Case([&](omp::SimdLoopOp) {
1372         return convertOmpSimdLoop(*op, builder, moduleTranslation);
1373       })
1374       .Case([&](omp::AtomicReadOp) {
1375         return convertOmpAtomicRead(*op, builder, moduleTranslation);
1376       })
1377       .Case([&](omp::AtomicWriteOp) {
1378         return convertOmpAtomicWrite(*op, builder, moduleTranslation);
1379       })
1380       .Case([&](omp::AtomicUpdateOp op) {
1381         return convertOmpAtomicUpdate(op, builder, moduleTranslation);
1382       })
1383       .Case([&](omp::AtomicCaptureOp op) {
1384         return convertOmpAtomicCapture(op, builder, moduleTranslation);
1385       })
1386       .Case([&](omp::SectionsOp) {
1387         return convertOmpSections(*op, builder, moduleTranslation);
1388       })
1389       .Case([&](omp::SingleOp op) {
1390         return convertOmpSingle(op, builder, moduleTranslation);
1391       })
1392       .Case<omp::YieldOp, omp::TerminatorOp, omp::ReductionDeclareOp,
1393             omp::CriticalDeclareOp>([](auto op) {
1394         // `yield` and `terminator` can be just omitted. The block structure
1395         // was created in the region that handles their parent operation.
1396         // `reduction.declare` will be used by reductions and is not
1397         // converted directly, skip it.
1398         // `critical.declare` is only used to declare names of critical
1399         // sections which will be used by `critical` ops and hence can be
1400         // ignored for lowering. The OpenMP IRBuilder will create unique
1401         // name for critical section names.
1402         return success();
1403       })
1404       .Case([&](omp::ThreadprivateOp) {
1405         return convertOmpThreadprivate(*op, builder, moduleTranslation);
1406       })
1407       .Default([&](Operation *inst) {
1408         return inst->emitError("unsupported OpenMP operation: ")
1409                << inst->getName();
1410       });
1411 }
1412 
1413 void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
1414   registry.insert<omp::OpenMPDialect>();
1415   registry.addExtension(+[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
1416     dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
1417   });
1418 }
1419 
1420 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
1421   DialectRegistry registry;
1422   registerOpenMPDialectTranslation(registry);
1423   context.appendDialectRegistry(registry);
1424 }
1425