1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
22 #include "llvm/IR/IRBuilder.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 /// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
28 /// insertion points for allocas.
29 class OpenMPAllocaStackFrame
30     : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
31 public:
32   explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
33       : allocaInsertPoint(allocaIP) {}
34   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
35 };
36 } // namespace
37 
38 /// Find the insertion point for allocas given the current insertion point for
39 /// normal operations in the builder.
40 static llvm::OpenMPIRBuilder::InsertPointTy
41 findAllocaInsertPoint(llvm::IRBuilderBase &builder,
42                       const LLVM::ModuleTranslation &moduleTranslation) {
43   // If there is an alloca insertion point on stack, i.e. we are in a nested
44   // operation and a specific point was provided by some surrounding operation,
45   // use it.
46   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
47   WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
48       [&](const OpenMPAllocaStackFrame &frame) {
49         allocaInsertPoint = frame.allocaInsertPoint;
50         return WalkResult::interrupt();
51       });
52   if (walkResult.wasInterrupted())
53     return allocaInsertPoint;
54 
55   // Otherwise, insert to the entry block of the surrounding function.
56   llvm::BasicBlock &funcEntryBlock =
57       builder.GetInsertBlock()->getParent()->getEntryBlock();
58   return llvm::OpenMPIRBuilder::InsertPointTy(
59       &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
60 }
61 
62 /// Converts the given region that appears within an OpenMP dialect operation to
63 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
64 /// region, and a branch from any block with an successor-less OpenMP terminator
65 /// to `continuationBlock`.
66 static void convertOmpOpRegions(Region &region, StringRef blockName,
67                                 llvm::BasicBlock &sourceBlock,
68                                 llvm::BasicBlock &continuationBlock,
69                                 llvm::IRBuilderBase &builder,
70                                 LLVM::ModuleTranslation &moduleTranslation,
71                                 LogicalResult &bodyGenStatus) {
72   llvm::LLVMContext &llvmContext = builder.getContext();
73   for (Block &bb : region) {
74     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
75         llvmContext, blockName, builder.GetInsertBlock()->getParent());
76     moduleTranslation.mapBlock(&bb, llvmBB);
77   }
78 
79   llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
80 
81   // Convert blocks one by one in topological order to ensure
82   // defs are converted before uses.
83   SetVector<Block *> blocks =
84       LLVM::detail::getTopologicallySortedBlocks(region);
85   for (Block *bb : blocks) {
86     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
87     // Retarget the branch of the entry block to the entry block of the
88     // converted region (regions are single-entry).
89     if (bb->isEntryBlock()) {
90       assert(sourceTerminator->getNumSuccessors() == 1 &&
91              "provided entry block has multiple successors");
92       assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
93              "ContinuationBlock is not the successor of the entry block");
94       sourceTerminator->setSuccessor(0, llvmBB);
95     }
96 
97     llvm::IRBuilderBase::InsertPointGuard guard(builder);
98     if (failed(
99             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
100       bodyGenStatus = failure();
101       return;
102     }
103 
104     // Special handling for `omp.yield` and `omp.terminator` (we may have more
105     // than one): they return the control to the parent OpenMP dialect operation
106     // so replace them with the branch to the continuation block. We handle this
107     // here to avoid relying inter-function communication through the
108     // ModuleTranslation class to set up the correct insertion point. This is
109     // also consistent with MLIR's idiom of handling special region terminators
110     // in the same code that handles the region-owning operation.
111     if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
112       builder.CreateBr(&continuationBlock);
113   }
114   // Finally, after all blocks have been traversed and values mapped,
115   // connect the PHI nodes to the results of preceding blocks.
116   LLVM::detail::connectPHINodes(region, moduleTranslation);
117 }
118 
119 /// Converts the OpenMP parallel operation to LLVM IR.
120 static LogicalResult
121 convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
122                    LLVM::ModuleTranslation &moduleTranslation) {
123   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
124   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
125   // relying on captured variables.
126   LogicalResult bodyGenStatus = success();
127 
128   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
129                        llvm::BasicBlock &continuationBlock) {
130     // Save the alloca insertion point on ModuleTranslation stack for use in
131     // nested regions.
132     LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
133         moduleTranslation, allocaIP);
134 
135     // ParallelOp has only one region associated with it.
136     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
137     convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
138                         continuationBlock, builder, moduleTranslation,
139                         bodyGenStatus);
140   };
141 
142   // TODO: Perform appropriate actions according to the data-sharing
143   // attribute (shared, private, firstprivate, ...) of variables.
144   // Currently defaults to shared.
145   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
146                     llvm::Value &, llvm::Value &vPtr,
147                     llvm::Value *&replacementValue) -> InsertPointTy {
148     replacementValue = &vPtr;
149 
150     return codeGenIP;
151   };
152 
153   // TODO: Perform finalization actions for variables. This has to be
154   // called for variables which have destructors/finalizers.
155   auto finiCB = [&](InsertPointTy codeGenIP) {};
156 
157   llvm::Value *ifCond = nullptr;
158   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
159     ifCond = moduleTranslation.lookupValue(ifExprVar);
160   llvm::Value *numThreads = nullptr;
161   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
162     numThreads = moduleTranslation.lookupValue(numThreadsVar);
163   llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
164   if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
165     pbKind = llvm::omp::getProcBindKind(bind.getValue());
166   // TODO: Is the Parallel construct cancellable?
167   bool isCancellable = false;
168 
169   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
170       builder.saveIP(), builder.getCurrentDebugLocation());
171   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
172       ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB,
173       privCB, finiCB, ifCond, numThreads, pbKind, isCancellable));
174 
175   return bodyGenStatus;
176 }
177 
178 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
179 static LogicalResult
180 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
181                  LLVM::ModuleTranslation &moduleTranslation) {
182   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
183   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
184   // relying on captured variables.
185   LogicalResult bodyGenStatus = success();
186 
187   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
188                        llvm::BasicBlock &continuationBlock) {
189     // MasterOp has only one region associated with it.
190     auto &region = cast<omp::MasterOp>(opInst).getRegion();
191     convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
192                         continuationBlock, builder, moduleTranslation,
193                         bodyGenStatus);
194   };
195 
196   // TODO: Perform finalization actions for variables. This has to be
197   // called for variables which have destructors/finalizers.
198   auto finiCB = [&](InsertPointTy codeGenIP) {};
199 
200   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
201       builder.saveIP(), builder.getCurrentDebugLocation());
202   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
203       ompLoc, bodyGenCB, finiCB));
204   return success();
205 }
206 
207 /// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
208 static LogicalResult
209 convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
210                    LLVM::ModuleTranslation &moduleTranslation) {
211   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
212   auto criticalOp = cast<omp::CriticalOp>(opInst);
213   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
214   // relying on captured variables.
215   LogicalResult bodyGenStatus = success();
216 
217   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
218                        llvm::BasicBlock &continuationBlock) {
219     // CriticalOp has only one region associated with it.
220     auto &region = cast<omp::CriticalOp>(opInst).getRegion();
221     convertOmpOpRegions(region, "omp.critical.region", *codeGenIP.getBlock(),
222                         continuationBlock, builder, moduleTranslation,
223                         bodyGenStatus);
224   };
225 
226   // TODO: Perform finalization actions for variables. This has to be
227   // called for variables which have destructors/finalizers.
228   auto finiCB = [&](InsertPointTy codeGenIP) {};
229 
230   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
231       builder.saveIP(), builder.getCurrentDebugLocation());
232   llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
233   llvm::Constant *hint = nullptr;
234   if (criticalOp.hint().hasValue()) {
235     hint =
236         llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext),
237                                static_cast<int>(criticalOp.hint().getValue()));
238   } else {
239     hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0);
240   }
241   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical(
242       ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint));
243   return success();
244 }
245 
/// Converts an OpenMP workshare loop (`omp.wsloop`) into LLVM IR using
/// OpenMPIRBuilder.
///
/// One canonical loop is created per loop dimension, each nested in the body
/// of the previous one. The nest is then collapsed into a single loop. A static
/// workshare schedule is applied when the schedule is Static (the default when
/// `schedule_val` is absent); otherwise a dynamic schedule is used. The loop
/// body region is converted only inside the innermost canonical loop. Returns
/// failure if the op has no bounds or the body fails to convert. On success,
/// the builder is left right after the loop nest.
static LogicalResult
convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
  auto loop = cast<omp::WsLoopOp>(opInst);
  // TODO: this should be in the op verifier instead.
  if (loop.lowerBound().empty())
    return failure();

  // Static is the default.
  // NOTE(review): the Optional returned by symbolizeClauseScheduleKind is
  // dereferenced unchecked; this assumes `schedule_val` always spells a valid
  // schedule kind (presumably enforced by the op verifier -- confirm).
  omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
  if (loop.schedule_val().hasValue())
    schedule =
        *omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue());

  // Set up the source location value for OpenMP runtime.
  llvm::DISubprogram *subprogram =
      builder.GetInsertBlock()->getParent()->getSubprogram();
  const llvm::DILocation *diLoc =
      moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
                                                    llvm::DebugLoc(diLoc));

  // Generator of the canonical loop body.
  // TODO: support error propagation in OpenMPIRBuilder and use it instead of
  // relying on captured variables.
  // `bodyGen` runs during each createCanonicalLoop call below, before that
  // loop's info is pushed onto `loopInfos`. Inside the callback,
  // `loopInfos.size()` is therefore the index of the loop currently being
  // built, which selects the matching induction-variable block argument.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  LogicalResult bodyGenStatus = success();
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loop.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(ip);

    // Only the innermost loop receives the converted region body.
    if (loopInfos.size() != loop.getNumLoops() - 1)
      return;

    // Convert the body of the loop. The current body block is split at `ip`:
    // the region is spliced in between the first half and the new
    // "omp.wsloop.exit" block.
    llvm::BasicBlock *entryBlock = ip.getBlock();
    llvm::BasicBlock *exitBlock =
        entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
    convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
                        *exitBlock, builder, moduleTranslation, bodyGenStatus);
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
  // i.e. it has a positive step, uses signed integer semantics. Reconsider
  // this code when WsLoop clearly supports more cases.
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(loop.lowerBound()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(loop.upperBound()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);

    // Make sure the loop trip counts are emitted in the preheader of the
    // outermost loop at the latest, so that they are all available for the
    // collapsed loop created below. Every loop after the first is created at
    // the body insertion point of the previously created (enclosing) loop.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       llvm::DebugLoc(diLoc));
      computeIP = loopInfos.front()->getPreheaderIP();
    }
    loopInfos.push_back(ompBuilder->createCanonicalLoop(
        loc, bodyGen, lowerBound, upperBound, step,
        /*IsSigned=*/true, loop.inclusive(), computeIP));

    // `bodyGenStatus` is only updated by the innermost loop's body generation.
    if (failed(bodyGenStatus))
      return failure();
  }

  // Collapse loops. Store the insertion point because LoopInfos may get
  // invalidated.
  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
  llvm::CanonicalLoopInfo *loopInfo =
      ompBuilder->collapseLoops(diLoc, loopInfos, {});

  // Find the loop configuration. Without an explicit chunk size, a chunk of 1
  // (of the induction variable's type) is used.
  llvm::Type *ivType = loopInfo->getIndVar()->getType();
  llvm::Value *chunk =
      loop.schedule_chunk_var()
          ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
          : llvm::ConstantInt::get(ivType, 1);
  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  // `!loop.nowait()` is passed as the builder's barrier flag: a barrier is
  // requested unless the `nowait` clause is present.
  if (schedule == omp::ClauseScheduleKind::Static) {
    ompBuilder->applyStaticWorkshareLoop(ompLoc.DL, loopInfo, allocaIP,
                                         !loop.nowait(), chunk);
  } else {
    // Map the dialect schedule kind to the runtime schedule type. Static was
    // handled above, so it cannot reach the default case.
    llvm::omp::OMPScheduleType schedType;
    switch (schedule) {
    case omp::ClauseScheduleKind::Dynamic:
      schedType = llvm::omp::OMPScheduleType::DynamicChunked;
      break;
    case omp::ClauseScheduleKind::Guided:
      schedType = llvm::omp::OMPScheduleType::GuidedChunked;
      break;
    case omp::ClauseScheduleKind::Auto:
      schedType = llvm::omp::OMPScheduleType::Auto;
      break;
    case omp::ClauseScheduleKind::Runtime:
      schedType = llvm::omp::OMPScheduleType::Runtime;
      break;
    default:
      llvm_unreachable("Unknown schedule value");
      break;
    }

    ompBuilder->applyDynamicWorkshareLoop(ompLoc.DL, loopInfo, allocaIP,
                                          schedType, !loop.nowait(), chunk);
  }

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(afterIP);
  return success();
}
374 
375 namespace {
376 
377 /// Implementation of the dialect interface that converts operations belonging
378 /// to the OpenMP dialect to LLVM IR.
379 class OpenMPDialectLLVMIRTranslationInterface
380     : public LLVMTranslationDialectInterface {
381 public:
382   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
383 
384   /// Translates the given operation to LLVM IR using the provided IR builder
385   /// and saving the state in `moduleTranslation`.
386   LogicalResult
387   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
388                    LLVM::ModuleTranslation &moduleTranslation) const final;
389 };
390 
391 } // end namespace
392 
393 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
394 /// (including OpenMP runtime calls).
395 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
396     Operation *op, llvm::IRBuilderBase &builder,
397     LLVM::ModuleTranslation &moduleTranslation) const {
398 
399   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
400 
401   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
402       .Case([&](omp::BarrierOp) {
403         ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
404         return success();
405       })
406       .Case([&](omp::TaskwaitOp) {
407         ompBuilder->createTaskwait(builder.saveIP());
408         return success();
409       })
410       .Case([&](omp::TaskyieldOp) {
411         ompBuilder->createTaskyield(builder.saveIP());
412         return success();
413       })
414       .Case([&](omp::FlushOp) {
415         // No support in Openmp runtime function (__kmpc_flush) to accept
416         // the argument list.
417         // OpenMP standard states the following:
418         //  "An implementation may implement a flush with a list by ignoring
419         //   the list, and treating it the same as a flush without a list."
420         //
421         // The argument list is discarded so that, flush with a list is treated
422         // same as a flush without a list.
423         ompBuilder->createFlush(builder.saveIP());
424         return success();
425       })
426       .Case([&](omp::ParallelOp) {
427         return convertOmpParallel(*op, builder, moduleTranslation);
428       })
429       .Case([&](omp::MasterOp) {
430         return convertOmpMaster(*op, builder, moduleTranslation);
431       })
432       .Case([&](omp::CriticalOp) {
433         return convertOmpCritical(*op, builder, moduleTranslation);
434       })
435       .Case([&](omp::WsLoopOp) {
436         return convertOmpWsLoop(*op, builder, moduleTranslation);
437       })
438       .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) {
439         // `yield` and `terminator` can be just omitted. The block structure was
440         // created in the function that handles their parent operation.
441         assert(op->getNumOperands() == 0 &&
442                "unexpected OpenMP terminator with operands");
443         return success();
444       })
445       .Default([&](Operation *inst) {
446         return inst->emitError("unsupported OpenMP operation: ")
447                << inst->getName();
448       });
449 }
450 
451 void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
452   registry.insert<omp::OpenMPDialect>();
453   registry.addDialectInterface<omp::OpenMPDialect,
454                                OpenMPDialectLLVMIRTranslationInterface>();
455 }
456 
457 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
458   DialectRegistry registry;
459   registerOpenMPDialectTranslation(registry);
460   context.appendDialectRegistry(registry);
461 }
462