1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
22 #include "llvm/IR/IRBuilder.h"
23 
24 using namespace mlir;
25 
26 /// Converts the given region that appears within an OpenMP dialect operation to
27 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
28 /// region, and a branch from any block with an successor-less OpenMP terminator
29 /// to `continuationBlock`.
30 static void convertOmpOpRegions(Region &region, StringRef blockName,
31                                 llvm::BasicBlock &sourceBlock,
32                                 llvm::BasicBlock &continuationBlock,
33                                 llvm::IRBuilderBase &builder,
34                                 LLVM::ModuleTranslation &moduleTranslation,
35                                 LogicalResult &bodyGenStatus) {
36   llvm::LLVMContext &llvmContext = builder.getContext();
37   for (Block &bb : region) {
38     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
39         llvmContext, blockName, builder.GetInsertBlock()->getParent());
40     moduleTranslation.mapBlock(&bb, llvmBB);
41   }
42 
43   llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
44 
45   // Convert blocks one by one in topological order to ensure
46   // defs are converted before uses.
47   SetVector<Block *> blocks =
48       LLVM::detail::getTopologicallySortedBlocks(region);
49   for (Block *bb : blocks) {
50     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
51     // Retarget the branch of the entry block to the entry block of the
52     // converted region (regions are single-entry).
53     if (bb->isEntryBlock()) {
54       assert(sourceTerminator->getNumSuccessors() == 1 &&
55              "provided entry block has multiple successors");
56       assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
57              "ContinuationBlock is not the successor of the entry block");
58       sourceTerminator->setSuccessor(0, llvmBB);
59     }
60 
61     llvm::IRBuilderBase::InsertPointGuard guard(builder);
62     if (failed(
63             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
64       bodyGenStatus = failure();
65       return;
66     }
67 
68     // Special handling for `omp.yield` and `omp.terminator` (we may have more
69     // than one): they return the control to the parent OpenMP dialect operation
70     // so replace them with the branch to the continuation block. We handle this
71     // here to avoid relying inter-function communication through the
72     // ModuleTranslation class to set up the correct insertion point. This is
73     // also consistent with MLIR's idiom of handling special region terminators
74     // in the same code that handles the region-owning operation.
75     if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
76       builder.CreateBr(&continuationBlock);
77   }
78   // Finally, after all blocks have been traversed and values mapped,
79   // connect the PHI nodes to the results of preceding blocks.
80   LLVM::detail::connectPHINodes(region, moduleTranslation);
81 }
82 
83 /// Converts the OpenMP parallel operation to LLVM IR.
84 static LogicalResult
85 convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
86                    LLVM::ModuleTranslation &moduleTranslation) {
87   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
88   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
89   // relying on captured variables.
90   LogicalResult bodyGenStatus = success();
91 
92   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
93                        llvm::BasicBlock &continuationBlock) {
94     // ParallelOp has only one region associated with it.
95     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
96     convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
97                         continuationBlock, builder, moduleTranslation,
98                         bodyGenStatus);
99   };
100 
101   // TODO: Perform appropriate actions according to the data-sharing
102   // attribute (shared, private, firstprivate, ...) of variables.
103   // Currently defaults to shared.
104   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
105                     llvm::Value &, llvm::Value &vPtr,
106                     llvm::Value *&replacementValue) -> InsertPointTy {
107     replacementValue = &vPtr;
108 
109     return codeGenIP;
110   };
111 
112   // TODO: Perform finalization actions for variables. This has to be
113   // called for variables which have destructors/finalizers.
114   auto finiCB = [&](InsertPointTy codeGenIP) {};
115 
116   llvm::Value *ifCond = nullptr;
117   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
118     ifCond = moduleTranslation.lookupValue(ifExprVar);
119   llvm::Value *numThreads = nullptr;
120   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
121     numThreads = moduleTranslation.lookupValue(numThreadsVar);
122   llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
123   if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
124     pbKind = llvm::omp::getProcBindKind(bind.getValue());
125   // TODO: Is the Parallel construct cancellable?
126   bool isCancellable = false;
127   // TODO: Determine the actual alloca insertion point, e.g., the function
128   // entry or the alloca insertion point as provided by the body callback
129   // above.
130   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
131   if (failed(bodyGenStatus))
132     return failure();
133   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
134       builder.saveIP(), builder.getCurrentDebugLocation());
135   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
136       ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
137       isCancellable));
138   return success();
139 }
140 
141 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
142 static LogicalResult
143 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
144                  LLVM::ModuleTranslation &moduleTranslation) {
145   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
146   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
147   // relying on captured variables.
148   LogicalResult bodyGenStatus = success();
149 
150   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
151                        llvm::BasicBlock &continuationBlock) {
152     // MasterOp has only one region associated with it.
153     auto &region = cast<omp::MasterOp>(opInst).getRegion();
154     convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
155                         continuationBlock, builder, moduleTranslation,
156                         bodyGenStatus);
157   };
158 
159   // TODO: Perform finalization actions for variables. This has to be
160   // called for variables which have destructors/finalizers.
161   auto finiCB = [&](InsertPointTy codeGenIP) {};
162 
163   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
164       builder.saveIP(), builder.getCurrentDebugLocation());
165   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
166       ompLoc, bodyGenCB, finiCB));
167   return success();
168 }
169 
170 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
171 static LogicalResult
172 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
173                  LLVM::ModuleTranslation &moduleTranslation) {
174   auto loop = cast<omp::WsLoopOp>(opInst);
175   // TODO: this should be in the op verifier instead.
176   if (loop.lowerBound().empty())
177     return failure();
178 
179   if (loop.getNumLoops() != 1)
180     return opInst.emitOpError("collapsed loops not yet supported");
181 
182   bool isStatic = true;
183 
184   if (loop.schedule_val().hasValue()) {
185     auto schedule =
186         omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue());
187     if (schedule != omp::ClauseScheduleKind::Static &&
188         schedule != omp::ClauseScheduleKind::Dynamic)
189       return opInst.emitOpError("only static (default) and dynamic loop "
190                                 "schedule is currently supported");
191     isStatic = (schedule == omp::ClauseScheduleKind::Static);
192   }
193 
194   // Find the loop configuration.
195   llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
196   llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
197   llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
198   llvm::Type *ivType = step->getType();
199   llvm::Value *chunk =
200       loop.schedule_chunk_var()
201           ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
202           : llvm::ConstantInt::get(ivType, 1);
203 
204   // Set up the source location value for OpenMP runtime.
205   llvm::DISubprogram *subprogram =
206       builder.GetInsertBlock()->getParent()->getSubprogram();
207   const llvm::DILocation *diLoc =
208       moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
209   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
210                                                     llvm::DebugLoc(diLoc));
211 
212   // Generator of the canonical loop body. Produces an SESE region of basic
213   // blocks.
214   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
215   // relying on captured variables.
216   LogicalResult bodyGenStatus = success();
217   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
218     llvm::IRBuilder<>::InsertPointGuard guard(builder);
219 
220     // Make sure further conversions know about the induction variable.
221     moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
222 
223     llvm::BasicBlock *entryBlock = ip.getBlock();
224     llvm::BasicBlock *exitBlock =
225         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
226 
227     // Convert the body of the loop.
228     convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
229                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
230   };
231 
232   // Delegate actual loop construction to the OpenMP IRBuilder.
233   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
234   // i.e. it has a positive step, uses signed integer semantics. Reconsider
235   // this code when WsLoop clearly supports more cases.
236   llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
237   llvm::CanonicalLoopInfo *loopInfo =
238       moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
239           ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
240           /*InclusiveStop=*/loop.inclusive());
241   if (failed(bodyGenStatus))
242     return failure();
243 
244   // TODO: get the alloca insertion point from the parallel operation builder.
245   // If we insert the at the top of the current function, they will be passed as
246   // extra arguments into the function the parallel operation builder outlines.
247   // Put them at the start of the current block for now.
248   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
249       insertBlock, insertBlock->getFirstInsertionPt());
250   llvm::OpenMPIRBuilder::InsertPointTy afterIP;
251   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
252   if (isStatic) {
253     loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
254                                                      !loop.nowait(), chunk);
255     afterIP = loopInfo->getAfterIP();
256   } else {
257     afterIP = ompBuilder->createDynamicWorkshareLoop(ompLoc, loopInfo, allocaIP,
258                                                      !loop.nowait(), chunk);
259   }
260 
261   // Continue building IR after the loop.
262   builder.restoreIP(afterIP);
263   return success();
264 }
265 
266 namespace {
267 
268 /// Implementation of the dialect interface that converts operations belonging
269 /// to the OpenMP dialect to LLVM IR.
270 class OpenMPDialectLLVMIRTranslationInterface
271     : public LLVMTranslationDialectInterface {
272 public:
273   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
274 
275   /// Translates the given operation to LLVM IR using the provided IR builder
276   /// and saving the state in `moduleTranslation`.
277   LogicalResult
278   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
279                    LLVM::ModuleTranslation &moduleTranslation) const final;
280 };
281 
282 } // end namespace
283 
284 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
285 /// (including OpenMP runtime calls).
286 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
287     Operation *op, llvm::IRBuilderBase &builder,
288     LLVM::ModuleTranslation &moduleTranslation) const {
289 
290   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
291 
292   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
293       .Case([&](omp::BarrierOp) {
294         ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
295         return success();
296       })
297       .Case([&](omp::TaskwaitOp) {
298         ompBuilder->createTaskwait(builder.saveIP());
299         return success();
300       })
301       .Case([&](omp::TaskyieldOp) {
302         ompBuilder->createTaskyield(builder.saveIP());
303         return success();
304       })
305       .Case([&](omp::FlushOp) {
306         // No support in Openmp runtime function (__kmpc_flush) to accept
307         // the argument list.
308         // OpenMP standard states the following:
309         //  "An implementation may implement a flush with a list by ignoring
310         //   the list, and treating it the same as a flush without a list."
311         //
312         // The argument list is discarded so that, flush with a list is treated
313         // same as a flush without a list.
314         ompBuilder->createFlush(builder.saveIP());
315         return success();
316       })
317       .Case([&](omp::ParallelOp) {
318         return convertOmpParallel(*op, builder, moduleTranslation);
319       })
320       .Case([&](omp::MasterOp) {
321         return convertOmpMaster(*op, builder, moduleTranslation);
322       })
323       .Case([&](omp::WsLoopOp) {
324         return convertOmpWsLoop(*op, builder, moduleTranslation);
325       })
326       .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) {
327         // `yield` and `terminator` can be just omitted. The block structure was
328         // created in the function that handles their parent operation.
329         assert(op->getNumOperands() == 0 &&
330                "unexpected OpenMP terminator with operands");
331         return success();
332       })
333       .Default([&](Operation *inst) {
334         return inst->emitError("unsupported OpenMP operation: ")
335                << inst->getName();
336       });
337 }
338 
339 void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
340   registry.insert<omp::OpenMPDialect>();
341   registry.addDialectInterface<omp::OpenMPDialect,
342                                OpenMPDialectLLVMIRTranslationInterface>();
343 }
344 
345 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
346   DialectRegistry registry;
347   registerOpenMPDialectTranslation(registry);
348   context.appendDialectRegistry(registry);
349 }
350