1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
22 #include "llvm/IR/IRBuilder.h"
23 
24 using namespace mlir;
25 
26 /// Converts the given region that appears within an OpenMP dialect operation to
27 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
28 /// region, and a branch from any block with an successor-less OpenMP terminator
29 /// to `continuationBlock`.
30 static void convertOmpOpRegions(Region &region, StringRef blockName,
31                                 llvm::BasicBlock &sourceBlock,
32                                 llvm::BasicBlock &continuationBlock,
33                                 llvm::IRBuilderBase &builder,
34                                 LLVM::ModuleTranslation &moduleTranslation,
35                                 LogicalResult &bodyGenStatus) {
36   llvm::LLVMContext &llvmContext = builder.getContext();
37   for (Block &bb : region) {
38     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
39         llvmContext, blockName, builder.GetInsertBlock()->getParent());
40     moduleTranslation.mapBlock(&bb, llvmBB);
41   }
42 
43   llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
44 
45   // Convert blocks one by one in topological order to ensure
46   // defs are converted before uses.
47   llvm::SetVector<Block *> blocks =
48       LLVM::detail::getTopologicallySortedBlocks(region);
49   for (Block *bb : blocks) {
50     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
51     // Retarget the branch of the entry block to the entry block of the
52     // converted region (regions are single-entry).
53     if (bb->isEntryBlock()) {
54       assert(sourceTerminator->getNumSuccessors() == 1 &&
55              "provided entry block has multiple successors");
56       assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
57              "ContinuationBlock is not the successor of the entry block");
58       sourceTerminator->setSuccessor(0, llvmBB);
59     }
60 
61     llvm::IRBuilderBase::InsertPointGuard guard(builder);
62     if (failed(
63             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
64       bodyGenStatus = failure();
65       return;
66     }
67 
68     // Special handling for `omp.yield` and `omp.terminator` (we may have more
69     // than one): they return the control to the parent OpenMP dialect operation
70     // so replace them with the branch to the continuation block. We handle this
71     // here to avoid relying inter-function communication through the
72     // ModuleTranslation class to set up the correct insertion point. This is
73     // also consistent with MLIR's idiom of handling special region terminators
74     // in the same code that handles the region-owning operation.
75     if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
76       builder.CreateBr(&continuationBlock);
77   }
78   // Finally, after all blocks have been traversed and values mapped,
79   // connect the PHI nodes to the results of preceding blocks.
80   LLVM::detail::connectPHINodes(region, moduleTranslation);
81 }
82 
83 /// Converts the OpenMP parallel operation to LLVM IR.
84 static LogicalResult
85 convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
86                    LLVM::ModuleTranslation &moduleTranslation) {
87   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
88   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
89   // relying on captured variables.
90   LogicalResult bodyGenStatus = success();
91 
92   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
93                        llvm::BasicBlock &continuationBlock) {
94     // ParallelOp has only one region associated with it.
95     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
96     convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
97                         continuationBlock, builder, moduleTranslation,
98                         bodyGenStatus);
99   };
100 
101   // TODO: Perform appropriate actions according to the data-sharing
102   // attribute (shared, private, firstprivate, ...) of variables.
103   // Currently defaults to shared.
104   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
105                     llvm::Value &, llvm::Value &vPtr,
106                     llvm::Value *&replacementValue) -> InsertPointTy {
107     replacementValue = &vPtr;
108 
109     return codeGenIP;
110   };
111 
112   // TODO: Perform finalization actions for variables. This has to be
113   // called for variables which have destructors/finalizers.
114   auto finiCB = [&](InsertPointTy codeGenIP) {};
115 
116   llvm::Value *ifCond = nullptr;
117   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
118     ifCond = moduleTranslation.lookupValue(ifExprVar);
119   llvm::Value *numThreads = nullptr;
120   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
121     numThreads = moduleTranslation.lookupValue(numThreadsVar);
122   llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
123   if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
124     pbKind = llvm::omp::getProcBindKind(bind.getValue());
125   // TODO: Is the Parallel construct cancellable?
126   bool isCancellable = false;
127   // TODO: Determine the actual alloca insertion point, e.g., the function
128   // entry or the alloca insertion point as provided by the body callback
129   // above.
130   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
131   if (failed(bodyGenStatus))
132     return failure();
133   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
134       builder.saveIP(), builder.getCurrentDebugLocation());
135   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
136       ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
137       isCancellable));
138   return success();
139 }
140 
141 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
142 static LogicalResult
143 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
144                  LLVM::ModuleTranslation &moduleTranslation) {
145   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
146   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
147   // relying on captured variables.
148   LogicalResult bodyGenStatus = success();
149 
150   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
151                        llvm::BasicBlock &continuationBlock) {
152     // MasterOp has only one region associated with it.
153     auto &region = cast<omp::MasterOp>(opInst).getRegion();
154     convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
155                         continuationBlock, builder, moduleTranslation,
156                         bodyGenStatus);
157   };
158 
159   // TODO: Perform finalization actions for variables. This has to be
160   // called for variables which have destructors/finalizers.
161   auto finiCB = [&](InsertPointTy codeGenIP) {};
162 
163   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
164       builder.saveIP(), builder.getCurrentDebugLocation());
165   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
166       ompLoc, bodyGenCB, finiCB));
167   return success();
168 }
169 
170 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
171 static LogicalResult
172 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
173                  LLVM::ModuleTranslation &moduleTranslation) {
174   auto loop = cast<omp::WsLoopOp>(opInst);
175   // TODO: this should be in the op verifier instead.
176   if (loop.lowerBound().empty())
177     return failure();
178 
179   if (loop.getNumLoops() != 1)
180     return opInst.emitOpError("collapsed loops not yet supported");
181 
182   if (loop.schedule_val().hasValue() &&
183       omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) !=
184           omp::ClauseScheduleKind::Static)
185     return opInst.emitOpError(
186         "only static (default) loop schedule is currently supported");
187 
188   // Find the loop configuration.
189   llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
190   llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
191   llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
192   llvm::Type *ivType = step->getType();
193   llvm::Value *chunk =
194       loop.schedule_chunk_var()
195           ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
196           : llvm::ConstantInt::get(ivType, 1);
197 
198   // Set up the source location value for OpenMP runtime.
199   llvm::DISubprogram *subprogram =
200       builder.GetInsertBlock()->getParent()->getSubprogram();
201   const llvm::DILocation *diLoc =
202       moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
203   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
204                                                     llvm::DebugLoc(diLoc));
205 
206   // Generator of the canonical loop body. Produces an SESE region of basic
207   // blocks.
208   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
209   // relying on captured variables.
210   LogicalResult bodyGenStatus = success();
211   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
212     llvm::IRBuilder<>::InsertPointGuard guard(builder);
213 
214     // Make sure further conversions know about the induction variable.
215     moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
216 
217     llvm::BasicBlock *entryBlock = ip.getBlock();
218     llvm::BasicBlock *exitBlock =
219         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
220 
221     // Convert the body of the loop.
222     convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
223                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
224   };
225 
226   // Delegate actual loop construction to the OpenMP IRBuilder.
227   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
228   // i.e. it has a positive step, uses signed integer semantics. Reconsider
229   // this code when WsLoop clearly supports more cases.
230   llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
231   llvm::CanonicalLoopInfo *loopInfo =
232       moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
233           ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
234           /*InclusiveStop=*/loop.inclusive());
235   if (failed(bodyGenStatus))
236     return failure();
237 
238   // TODO: get the alloca insertion point from the parallel operation builder.
239   // If we insert the at the top of the current function, they will be passed as
240   // extra arguments into the function the parallel operation builder outlines.
241   // Put them at the start of the current block for now.
242   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
243       insertBlock, insertBlock->getFirstInsertionPt());
244   loopInfo = moduleTranslation.getOpenMPBuilder()->createStaticWorkshareLoop(
245       ompLoc, loopInfo, allocaIP, !loop.nowait(), chunk);
246 
247   // Continue building IR after the loop.
248   builder.restoreIP(loopInfo->getAfterIP());
249   return success();
250 }
251 
252 namespace {
253 
254 /// Implementation of the dialect interface that converts operations belonging
255 /// to the OpenMP dialect to LLVM IR.
256 class OpenMPDialectLLVMIRTranslationInterface
257     : public LLVMTranslationDialectInterface {
258 public:
259   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
260 
261   /// Translates the given operation to LLVM IR using the provided IR builder
262   /// and saving the state in `moduleTranslation`.
263   LogicalResult
264   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
265                    LLVM::ModuleTranslation &moduleTranslation) const final;
266 };
267 
268 } // end namespace
269 
270 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
271 /// (including OpenMP runtime calls).
272 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
273     Operation *op, llvm::IRBuilderBase &builder,
274     LLVM::ModuleTranslation &moduleTranslation) const {
275 
276   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
277 
278   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
279       .Case([&](omp::BarrierOp) {
280         ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
281         return success();
282       })
283       .Case([&](omp::TaskwaitOp) {
284         ompBuilder->createTaskwait(builder.saveIP());
285         return success();
286       })
287       .Case([&](omp::TaskyieldOp) {
288         ompBuilder->createTaskyield(builder.saveIP());
289         return success();
290       })
291       .Case([&](omp::FlushOp) {
292         // No support in Openmp runtime function (__kmpc_flush) to accept
293         // the argument list.
294         // OpenMP standard states the following:
295         //  "An implementation may implement a flush with a list by ignoring
296         //   the list, and treating it the same as a flush without a list."
297         //
298         // The argument list is discarded so that, flush with a list is treated
299         // same as a flush without a list.
300         ompBuilder->createFlush(builder.saveIP());
301         return success();
302       })
303       .Case([&](omp::ParallelOp) {
304         return convertOmpParallel(*op, builder, moduleTranslation);
305       })
306       .Case([&](omp::MasterOp) {
307         return convertOmpMaster(*op, builder, moduleTranslation);
308       })
309       .Case([&](omp::WsLoopOp) {
310         return convertOmpWsLoop(*op, builder, moduleTranslation);
311       })
312       .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) {
313         // `yield` and `terminator` can be just omitted. The block structure was
314         // created in the function that handles their parent operation.
315         assert(op->getNumOperands() == 0 &&
316                "unexpected OpenMP terminator with operands");
317         return success();
318       })
319       .Default([&](Operation *inst) {
320         return inst->emitError("unsupported OpenMP operation: ")
321                << inst->getName();
322       });
323 }
324 
325 void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
326   registry.insert<omp::OpenMPDialect>();
327   registry.addDialectInterface<omp::OpenMPDialect,
328                                OpenMPDialectLLVMIRTranslationInterface>();
329 }
330 
331 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
332   DialectRegistry registry;
333   registerOpenMPDialectTranslation(registry);
334   context.appendDialectRegistry(registry);
335 }
336