1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenMP dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 #include "mlir/IR/Operation.h"
16 #include "mlir/Support/LLVM.h"
17 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
18 
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
22 #include "llvm/IR/IRBuilder.h"
23 
24 using namespace mlir;
25 
26 /// Converts the given region that appears within an OpenMP dialect operation to
27 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
28 /// region, and a branch from any block with an successor-less OpenMP terminator
29 /// to `continuationBlock`.
30 static void convertOmpOpRegions(Region &region, StringRef blockName,
31                                 llvm::BasicBlock &sourceBlock,
32                                 llvm::BasicBlock &continuationBlock,
33                                 llvm::IRBuilderBase &builder,
34                                 LLVM::ModuleTranslation &moduleTranslation,
35                                 LogicalResult &bodyGenStatus) {
36   llvm::LLVMContext &llvmContext = builder.getContext();
37   for (Block &bb : region) {
38     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
39         llvmContext, blockName, builder.GetInsertBlock()->getParent());
40     moduleTranslation.mapBlock(&bb, llvmBB);
41   }
42 
43   llvm::Instruction *sourceTerminator = sourceBlock.getTerminator();
44 
45   // Convert blocks one by one in topological order to ensure
46   // defs are converted before uses.
47   llvm::SetVector<Block *> blocks =
48       LLVM::detail::getTopologicallySortedBlocks(region);
49   for (Block *bb : blocks) {
50     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
51     // Retarget the branch of the entry block to the entry block of the
52     // converted region (regions are single-entry).
53     if (bb->isEntryBlock()) {
54       assert(sourceTerminator->getNumSuccessors() == 1 &&
55              "provided entry block has multiple successors");
56       assert(sourceTerminator->getSuccessor(0) == &continuationBlock &&
57              "ContinuationBlock is not the successor of the entry block");
58       sourceTerminator->setSuccessor(0, llvmBB);
59     }
60 
61     llvm::IRBuilder<>::InsertPointGuard guard(builder);
62     if (failed(moduleTranslation.convertBlock(
63             *bb, bb->isEntryBlock(),
64             // TODO: this downcast should be removed after all of
65             // ModuleTranslation migrated to using IRBuilderBase &; the cast is
66             // safe in practice because the builder always comes from
67             // ModuleTranslation itself that only uses this subclass.
68             static_cast<llvm::IRBuilder<> &>(builder)))) {
69       bodyGenStatus = failure();
70       return;
71     }
72 
73     // Special handling for `omp.yield` and `omp.terminator` (we may have more
74     // than one): they return the control to the parent OpenMP dialect operation
75     // so replace them with the branch to the continuation block. We handle this
76     // here to avoid relying inter-function communication through the
77     // ModuleTranslation class to set up the correct insertion point. This is
78     // also consistent with MLIR's idiom of handling special region terminators
79     // in the same code that handles the region-owning operation.
80     if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator()))
81       builder.CreateBr(&continuationBlock);
82   }
83   // Finally, after all blocks have been traversed and values mapped,
84   // connect the PHI nodes to the results of preceding blocks.
85   LLVM::detail::connectPHINodes(region, moduleTranslation);
86 }
87 
88 /// Converts the OpenMP parallel operation to LLVM IR.
89 static LogicalResult
90 convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder,
91                    LLVM::ModuleTranslation &moduleTranslation) {
92   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
93   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
94   // relying on captured variables.
95   LogicalResult bodyGenStatus = success();
96 
97   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
98                        llvm::BasicBlock &continuationBlock) {
99     // ParallelOp has only one region associated with it.
100     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
101     convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(),
102                         continuationBlock, builder, moduleTranslation,
103                         bodyGenStatus);
104   };
105 
106   // TODO: Perform appropriate actions according to the data-sharing
107   // attribute (shared, private, firstprivate, ...) of variables.
108   // Currently defaults to shared.
109   auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
110                     llvm::Value &, llvm::Value &vPtr,
111                     llvm::Value *&replacementValue) -> InsertPointTy {
112     replacementValue = &vPtr;
113 
114     return codeGenIP;
115   };
116 
117   // TODO: Perform finalization actions for variables. This has to be
118   // called for variables which have destructors/finalizers.
119   auto finiCB = [&](InsertPointTy codeGenIP) {};
120 
121   llvm::Value *ifCond = nullptr;
122   if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var())
123     ifCond = moduleTranslation.lookupValue(ifExprVar);
124   llvm::Value *numThreads = nullptr;
125   if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var())
126     numThreads = moduleTranslation.lookupValue(numThreadsVar);
127   llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default;
128   if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val())
129     pbKind = llvm::omp::getProcBindKind(bind.getValue());
130   // TODO: Is the Parallel construct cancellable?
131   bool isCancellable = false;
132   // TODO: Determine the actual alloca insertion point, e.g., the function
133   // entry or the alloca insertion point as provided by the body callback
134   // above.
135   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(builder.saveIP());
136   if (failed(bodyGenStatus))
137     return failure();
138   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
139       builder.saveIP(), builder.getCurrentDebugLocation());
140   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel(
141       ompLoc, allocaIP, bodyGenCB, privCB, finiCB, ifCond, numThreads, pbKind,
142       isCancellable));
143   return success();
144 }
145 
146 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
147 static LogicalResult
148 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
149                  LLVM::ModuleTranslation &moduleTranslation) {
150   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
151   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
152   // relying on captured variables.
153   LogicalResult bodyGenStatus = success();
154 
155   auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
156                        llvm::BasicBlock &continuationBlock) {
157     // MasterOp has only one region associated with it.
158     auto &region = cast<omp::MasterOp>(opInst).getRegion();
159     convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(),
160                         continuationBlock, builder, moduleTranslation,
161                         bodyGenStatus);
162   };
163 
164   // TODO: Perform finalization actions for variables. This has to be
165   // called for variables which have destructors/finalizers.
166   auto finiCB = [&](InsertPointTy codeGenIP) {};
167 
168   llvm::OpenMPIRBuilder::LocationDescription ompLoc(
169       builder.saveIP(), builder.getCurrentDebugLocation());
170   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster(
171       ompLoc, bodyGenCB, finiCB));
172   return success();
173 }
174 
175 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
176 LogicalResult convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
177                                LLVM::ModuleTranslation &moduleTranslation) {
178   auto loop = cast<omp::WsLoopOp>(opInst);
179   // TODO: this should be in the op verifier instead.
180   if (loop.lowerBound().empty())
181     return failure();
182 
183   if (loop.getNumLoops() != 1)
184     return opInst.emitOpError("collapsed loops not yet supported");
185 
186   if (loop.schedule_val().hasValue() &&
187       omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) !=
188           omp::ClauseScheduleKind::Static)
189     return opInst.emitOpError(
190         "only static (default) loop schedule is currently supported");
191 
192   // Find the loop configuration.
193   llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
194   llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
195   llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
196   llvm::Type *ivType = step->getType();
197   llvm::Value *chunk =
198       loop.schedule_chunk_var()
199           ? moduleTranslation.lookupValue(loop.schedule_chunk_var())
200           : llvm::ConstantInt::get(ivType, 1);
201 
202   // Set up the source location value for OpenMP runtime.
203   llvm::DISubprogram *subprogram =
204       builder.GetInsertBlock()->getParent()->getSubprogram();
205   const llvm::DILocation *diLoc =
206       moduleTranslation.translateLoc(opInst.getLoc(), subprogram);
207   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
208                                                     llvm::DebugLoc(diLoc));
209 
210   // Generator of the canonical loop body. Produces an SESE region of basic
211   // blocks.
212   // TODO: support error propagation in OpenMPIRBuilder and use it instead of
213   // relying on captured variables.
214   LogicalResult bodyGenStatus = success();
215   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
216     llvm::IRBuilder<>::InsertPointGuard guard(builder);
217 
218     // Make sure further conversions know about the induction variable.
219     moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
220 
221     llvm::BasicBlock *entryBlock = ip.getBlock();
222     llvm::BasicBlock *exitBlock =
223         entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
224 
225     // Convert the body of the loop.
226     convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
227                         *exitBlock, builder, moduleTranslation, bodyGenStatus);
228   };
229 
230   // Delegate actual loop construction to the OpenMP IRBuilder.
231   // TODO: this currently assumes WsLoop is semantically similar to SCF loop,
232   // i.e. it has a positive step, uses signed integer semantics. Reconsider
233   // this code when WsLoop clearly supports more cases.
234   llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
235   llvm::CanonicalLoopInfo *loopInfo =
236       moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
237           ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
238           /*InclusiveStop=*/loop.inclusive());
239   if (failed(bodyGenStatus))
240     return failure();
241 
242   // TODO: get the alloca insertion point from the parallel operation builder.
243   // If we insert the at the top of the current function, they will be passed as
244   // extra arguments into the function the parallel operation builder outlines.
245   // Put them at the start of the current block for now.
246   llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
247       insertBlock, insertBlock->getFirstInsertionPt());
248   loopInfo = moduleTranslation.getOpenMPBuilder()->createStaticWorkshareLoop(
249       ompLoc, loopInfo, allocaIP, !loop.nowait(), chunk);
250 
251   // Continue building IR after the loop.
252   builder.restoreIP(loopInfo->getAfterIP());
253   return success();
254 }
255 
256 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
257 /// (including OpenMP runtime calls).
258 LogicalResult mlir::OpenMPDialectLLVMIRTranslationInterface::convertOperation(
259     Operation *op, llvm::IRBuilderBase &builder,
260     LLVM::ModuleTranslation &moduleTranslation) const {
261 
262   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
263 
264   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
265       .Case([&](omp::BarrierOp) {
266         ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
267         return success();
268       })
269       .Case([&](omp::TaskwaitOp) {
270         ompBuilder->createTaskwait(builder.saveIP());
271         return success();
272       })
273       .Case([&](omp::TaskyieldOp) {
274         ompBuilder->createTaskyield(builder.saveIP());
275         return success();
276       })
277       .Case([&](omp::FlushOp) {
278         // No support in Openmp runtime function (__kmpc_flush) to accept
279         // the argument list.
280         // OpenMP standard states the following:
281         //  "An implementation may implement a flush with a list by ignoring
282         //   the list, and treating it the same as a flush without a list."
283         //
284         // The argument list is discarded so that, flush with a list is treated
285         // same as a flush without a list.
286         ompBuilder->createFlush(builder.saveIP());
287         return success();
288       })
289       .Case([&](omp::ParallelOp) {
290         return convertOmpParallel(*op, builder, moduleTranslation);
291       })
292       .Case([&](omp::MasterOp) {
293         return convertOmpMaster(*op, builder, moduleTranslation);
294       })
295       .Case([&](omp::WsLoopOp) {
296         return convertOmpWsLoop(*op, builder, moduleTranslation);
297       })
298       .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) {
299         // `yield` and `terminator` can be just omitted. The block structure was
300         // created in the function that handles their parent operation.
301         assert(op->getNumOperands() == 0 &&
302                "unexpected OpenMP terminator with operands");
303         return success();
304       })
305       .Default([&](Operation *inst) {
306         return inst->emitError("unsupported OpenMP operation: ")
307                << inst->getName();
308       });
309 }
310