1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
22 #include "llvm/IR/IRBuilder.h"
23 
24 using namespace mlir;
25 
26 /// Converts the given region that appears within an OpenMP dialect operation to
27 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
28 /// region, and a branch from any block with an successor-less OpenMP terminator
29 /// to `continuationBlock`.
30 static void convertOmpOpRegions(Region &region, StringRef blockName,
31                                 llvm::BasicBlock &sourceBlock,
32                                 llvm::BasicBlock &continuationBlock,
33                                 llvm::IRBuilderBase &builder,
34                                 LLVM::ModuleTranslation &moduleTranslation,
35                                 LogicalResult &bodyGenStatus) {
36   llvm::LLVMContext &llvmContext = builder.getContext();
37   for (Block &bb : region) {
38     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
39         llvmContext, blockName, builder.GetInsertBlock()->getParent());
40     moduleTranslation.mapBlock(&bb, llvmBB);
41   }
42 
43   llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
44 
45   // Convert blocks one by one in topological order to ensure
46   // defs are converted before uses.
47   llvm::SetVector<Block *> blocks =
48       LLVM::detail::getTopologicallySortedBlocks(region);
49   for (Block *bb : blocks) {
50     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
51     // Retarget the branch of the entry block to the entry block of the
52     // converted region (regions are single-entry).
53     if (bb->isEntryBlock()) {
54       assert(sourceTerminator->getNumSuccessors() == 1 &&
55              "provided entry block has multiple successors");
56       assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
57              "ContinuationBlock is not the successor of the entry block");
58       sourceTerminator->setSuccessor(0, llvmBB);
59     }
60 
61     llvm::IRBuilderBase::InsertPointGuard guard(builder);
62     if (failed(
63             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
64       bodyGenStatus = failure();
65       return;
66     }
67 
68     // Special handling for `omp.yield` and `omp.terminator` (we may have more
69     // than one): they return the control to the parent OpenMP dialect operation
70     // so replace them with the branch to the continuation block. We handle this
71     // here to avoid relying inter-function communication through the
72     // ModuleTranslation class to set up the correct insertion point. This is
73     // also consistent with MLIR's idiom of handling special region terminators
74     // in the same code that handles the region-owning operation.
75     if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
76       builder.CreateBr(&continuationBlock);
77   }
78   // Finally, after all blocks have been traversed and values mapped,
79   // connect the PHI nodes to the results of preceding blocks.
80   LLVM::detail::connectPHINodes(region, moduleTranslation);
81 }
82 
83 /// Converts the OpenMP parallel operation to LLVM IR.
84 static LogicalResult
85 convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
86                    LLVM::ModuleTranslation &moduleTranslation) {
87   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
88   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
89   // relying on captured variables.
90   LogicalResult bodyGenStatus = success();
91 
92   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
93                        llvm::BasicBlock &continuationBlock) {
94     // ParallelOp has only one region associated with it.
95     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
96     convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
97                         continuationBlock, builder, moduleTranslation,
98                         bodyGenStatus);
99   };
100 
101   // TODO: Perform appropriate actions according to the data-sharing
102   // attribute (shared, private, firstprivate, ...) of variables.
103   // Currently defaults to shared.
104   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
105                     llvm::Value &, llvm::Value &vPtr,
106                     llvm::Value *&replacementValue) -> InsertPointTy {
107     replacementValue = &vPtr;
108 
109     return codeGenIP;
110   };
111 
112   // TODO: Perform finalization actions for variables. This has to be
113   // called for variables which have destructors/finalizers.
114   auto finiCB = [&](InsertPointTy codeGenIP) {};
115 
116   llvm::Value *ifCond = nullptr;
117   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
118     ifCond = moduleTranslation.lookupValue(ifExprVar);
119   llvm::Value *numThreads = nullptr;
120   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
121     numThreads = moduleTranslation.lookupValue(numThreadsVar);
122   llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
123   if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
124     pbKind = llvm::omp::getProcBindKind(bind.getValue());
125   // TODO: Is the Parallel construct cancellable?
126   bool isCancellable = false;
127   // TODO: Determine the actual alloca insertion point, e.g., the function
128   // entry or the alloca insertion point as provided by the body callback
129   // above.
130   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
131   if (failed(bodyGenStatus))
132     return failure();
133   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
134       builder.saveIP(), builder.getCurrentDebugLocation());
135   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
136       ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
137       isCancellable));
138   return success();
139 }
140 
141 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
142 static LogicalResult
143 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
144                  LLVM::ModuleTranslation &moduleTranslation) {
145   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
146   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
147   // relying on captured variables.
148   LogicalResult bodyGenStatus = success();
149 
150   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
151                        llvm::BasicBlock &continuationBlock) {
152     // MasterOp has only one region associated with it.
153     auto &region = cast<omp::MasterOp>(opInst).getRegion();
154     convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
155                         continuationBlock, builder, moduleTranslation,
156                         bodyGenStatus);
157   };
158 
159   // TODO: Perform finalization actions for variables. This has to be
160   // called for variables which have destructors/finalizers.
161   auto finiCB = [&](InsertPointTy codeGenIP) {};
162 
163   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
164       builder.saveIP(), builder.getCurrentDebugLocation());
165   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
166       ompLoc, bodyGenCB, finiCB));
167   return success();
168 }
169 
170 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
171 LogicalResult convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
172                                LLVM::ModuleTranslation &moduleTranslation) {
173   auto loop = cast<omp::WsLoopOp>(opInst);
174   // TODO: this should be in the op verifier instead.
175   if (loop.lowerBound().empty())
176     return failure();
177 
178   if (loop.getNumLoops() != 1)
179     return opInst.emitOpError("collapsed loops not yet supported");
180 
181   if (loop.schedule_val().hasValue() &&
182       omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) !=
183           omp::ClauseScheduleKind::Static)
184     return opInst.emitOpError(
185         "only static (default) loop schedule is currently supported");
186 
187   // Find the loop configuration.
188   llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
189   llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
190   llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
191   llvm::Type *ivType = step->getType();
192   llvm::Value *chunk =
193       loop.schedule_chunk_var()
194           ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
195           : llvm::ConstantInt::get(ivType, 1);
196 
197   // Set up the source location value for OpenMP runtime.
198   llvm::DISubprogram *subprogram =
199       builder.GetInsertBlock()->getParent()->getSubprogram();
200   const llvm::DILocation *diLoc =
201       moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
202   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
203                                                     llvm::DebugLoc(diLoc));
204 
205   // Generator of the canonical loop body. Produces an SESE region of basic
206   // blocks.
207   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
208   // relying on captured variables.
209   LogicalResult bodyGenStatus = success();
210   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
211     llvm::IRBuilder<>::InsertPointGuard guard(builder);
212 
213     // Make sure further conversions know about the induction variable.
214     moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
215 
216     llvm::BasicBlock *entryBlock = ip.getBlock();
217     llvm::BasicBlock *exitBlock =
218         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
219 
220     // Convert the body of the loop.
221     convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
222                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
223   };
224 
225   // Delegate actual loop construction to the OpenMP IRBuilder.
226   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
227   // i.e. it has a positive step, uses signed integer semantics. Reconsider
228   // this code when WsLoop clearly supports more cases.
229   llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
230   llvm::CanonicalLoopInfo *loopInfo =
231       moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
232           ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
233           /*InclusiveStop=*/loop.inclusive());
234   if (failed(bodyGenStatus))
235     return failure();
236 
237   // TODO: get the alloca insertion point from the parallel operation builder.
238   // If we insert the at the top of the current function, they will be passed as
239   // extra arguments into the function the parallel operation builder outlines.
240   // Put them at the start of the current block for now.
241   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
242       insertBlock, insertBlock->getFirstInsertionPt());
243   loopInfo = moduleTranslation.getOpenMPBuilder()->createStaticWorkshareLoop(
244       ompLoc, loopInfo, allocaIP, !loop.nowait(), chunk);
245 
246   // Continue building IR after the loop.
247   builder.restoreIP(loopInfo->getAfterIP());
248   return success();
249 }
250 
251 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
252 /// (including OpenMP runtime calls).
253 LogicalResult mlir::OpenMPDialectLLVMIRTranslationInterface::convertOperation(
254     Operation *op, llvm::IRBuilderBase &builder,
255     LLVM::ModuleTranslation &moduleTranslation) const {
256 
257   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
258 
259   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
260       .Case([&](omp::BarrierOp) {
261         ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
262         return success();
263       })
264       .Case([&](omp::TaskwaitOp) {
265         ompBuilder->createTaskwait(builder.saveIP());
266         return success();
267       })
268       .Case([&](omp::TaskyieldOp) {
269         ompBuilder->createTaskyield(builder.saveIP());
270         return success();
271       })
272       .Case([&](omp::FlushOp) {
273         // No support in Openmp runtime function (__kmpc_flush) to accept
274         // the argument list.
275         // OpenMP standard states the following:
276         //  "An implementation may implement a flush with a list by ignoring
277         //   the list, and treating it the same as a flush without a list."
278         //
279         // The argument list is discarded so that, flush with a list is treated
280         // same as a flush without a list.
281         ompBuilder->createFlush(builder.saveIP());
282         return success();
283       })
284       .Case([&](omp::ParallelOp) {
285         return convertOmpParallel(*op, builder, moduleTranslation);
286       })
287       .Case([&](omp::MasterOp) {
288         return convertOmpMaster(*op, builder, moduleTranslation);
289       })
290       .Case([&](omp::WsLoopOp) {
291         return convertOmpWsLoop(*op, builder, moduleTranslation);
292       })
293       .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) {
294         // `yield` and `terminator` can be just omitted. The block structure was
295         // created in the function that handles their parent operation.
296         assert(op->getNumOperands() == 0 &&
297                "unexpected OpenMP terminator with operands");
298         return success();
299       })
300       .Default([&](Operation *inst) {
301         return inst->emitError("unsupported OpenMP operation: ")
302                << inst->getName();
303       });
304 }
305