1 //===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR OpenMP dialect and LLVM 10 // IR. 11 // 12 //===----------------------------------------------------------------------===// 13 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" 14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 15 #include "mlir/IR/Operation.h" 16 #include "mlir/Support/LLVM.h" 17 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 18 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" 22 #include "llvm/IR/IRBuilder.h" 23 24 using namespace mlir; 25 26 namespace { 27 /// ModuleTranslation stack frame for OpenMP operations. This keeps track of the 28 /// insertion points for allocas. 29 class OpenMPAllocaStackFrame 30 : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> { 31 public: 32 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP) 33 : allocaInsertPoint(allocaIP) {} 34 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint; 35 }; 36 } // namespace 37 38 /// Find the insertion point for allocas given the current insertion point for 39 /// normal operations in the builder. 40 static llvm::OpenMPIRBuilder::InsertPointTy 41 findAllocaInsertPoint(llvm::IRBuilderBase &builder, 42 const LLVM::ModuleTranslation &moduleTranslation) { 43 // If there is an alloca insertion point on stack, i.e. we are in a nested 44 // operation and a specific point was provided by some surrounding operation, 45 // use it. 46 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint; 47 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>( 48 [&](const OpenMPAllocaStackFrame &frame) { 49 allocaInsertPoint = frame.allocaInsertPoint; 50 return WalkResult::interrupt(); 51 }); 52 if (walkResult.wasInterrupted()) 53 return allocaInsertPoint; 54 55 // Otherwise, insert to the entry block of the surrounding function. 56 llvm::BasicBlock &funcEntryBlock = 57 builder.GetInsertBlock()->getParent()->getEntryBlock(); 58 return llvm::OpenMPIRBuilder::InsertPointTy( 59 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt()); 60 } 61 62 /// Converts the given region that appears within an OpenMP dialect operation to 63 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the 64 /// region, and a branch from any block with an successor-less OpenMP terminator 65 /// to `continuationBlock`. 66 static void convertOmpOpRegions(Region ®ion, StringRef blockName, 67 llvm::BasicBlock &sourceBlock, 68 llvm::BasicBlock &continuationBlock, 69 llvm::IRBuilderBase &builder, 70 LLVM::ModuleTranslation &moduleTranslation, 71 LogicalResult &bodyGenStatus) { 72 llvm::LLVMContext &llvmContext = builder.getContext(); 73 for (Block &bb : region) { 74 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create( 75 llvmContext, blockName, builder.GetInsertBlock()->getParent()); 76 moduleTranslation.mapBlock(&bb, llvmBB); 77 } 78 79 llvm::Instruction *sourceTerminator = sourceBlock.getTerminator(); 80 81 // Convert blocks one by one in topological order to ensure 82 // defs are converted before uses. 83 SetVector<Block *> blocks = 84 LLVM::detail::getTopologicallySortedBlocks(region); 85 for (Block *bb : blocks) { 86 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb); 87 // Retarget the branch of the entry block to the entry block of the 88 // converted region (regions are single-entry). 89 if (bb->isEntryBlock()) { 90 assert(sourceTerminator->getNumSuccessors() == 1 && 91 "provided entry block has multiple successors"); 92 assert(sourceTerminator->getSuccessor(0) == &continuationBlock && 93 "ContinuationBlock is not the successor of the entry block"); 94 sourceTerminator->setSuccessor(0, llvmBB); 95 } 96 97 llvm::IRBuilderBase::InsertPointGuard guard(builder); 98 if (failed( 99 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) { 100 bodyGenStatus = failure(); 101 return; 102 } 103 104 // Special handling for `omp.yield` and `omp.terminator` (we may have more 105 // than one): they return the control to the parent OpenMP dialect operation 106 // so replace them with the branch to the continuation block. We handle this 107 // here to avoid relying inter-function communication through the 108 // ModuleTranslation class to set up the correct insertion point. This is 109 // also consistent with MLIR's idiom of handling special region terminators 110 // in the same code that handles the region-owning operation. 111 if (isa<omp::TerminatorOp, omp::YieldOp>(bb->getTerminator())) 112 builder.CreateBr(&continuationBlock); 113 } 114 // Finally, after all blocks have been traversed and values mapped, 115 // connect the PHI nodes to the results of preceding blocks. 116 LLVM::detail::connectPHINodes(region, moduleTranslation); 117 } 118 119 /// Converts the OpenMP parallel operation to LLVM IR. 120 static LogicalResult 121 convertOmpParallel(Operation &opInst, llvm::IRBuilderBase &builder, 122 LLVM::ModuleTranslation &moduleTranslation) { 123 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; 124 // TODO: support error propagation in OpenMPIRBuilder and use it instead of 125 // relying on captured variables. 126 LogicalResult bodyGenStatus = success(); 127 128 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, 129 llvm::BasicBlock &continuationBlock) { 130 // Save the alloca insertion point on ModuleTranslation stack for use in 131 // nested regions. 132 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame( 133 moduleTranslation, allocaIP); 134 135 // ParallelOp has only one region associated with it. 136 auto ®ion = cast<omp::ParallelOp>(opInst).getRegion(); 137 convertOmpOpRegions(region, "omp.par.region", *codeGenIP.getBlock(), 138 continuationBlock, builder, moduleTranslation, 139 bodyGenStatus); 140 }; 141 142 // TODO: Perform appropriate actions according to the data-sharing 143 // attribute (shared, private, firstprivate, ...) of variables. 144 // Currently defaults to shared. 145 auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, 146 llvm::Value &, llvm::Value &vPtr, 147 llvm::Value *&replacementValue) -> InsertPointTy { 148 replacementValue = &vPtr; 149 150 return codeGenIP; 151 }; 152 153 // TODO: Perform finalization actions for variables. This has to be 154 // called for variables which have destructors/finalizers. 155 auto finiCB = [&](InsertPointTy codeGenIP) {}; 156 157 llvm::Value *ifCond = nullptr; 158 if (auto ifExprVar = cast<omp::ParallelOp>(opInst).if_expr_var()) 159 ifCond = moduleTranslation.lookupValue(ifExprVar); 160 llvm::Value *numThreads = nullptr; 161 if (auto numThreadsVar = cast<omp::ParallelOp>(opInst).num_threads_var()) 162 numThreads = moduleTranslation.lookupValue(numThreadsVar); 163 llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default; 164 if (auto bind = cast<omp::ParallelOp>(opInst).proc_bind_val()) 165 pbKind = llvm::omp::getProcBindKind(bind.getValue()); 166 // TODO: Is the Parallel construct cancellable? 167 bool isCancellable = false; 168 169 llvm::OpenMPIRBuilder::LocationDescription ompLoc( 170 builder.saveIP(), builder.getCurrentDebugLocation()); 171 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createParallel( 172 ompLoc, findAllocaInsertPoint(builder, moduleTranslation), bodyGenCB, 173 privCB, finiCB, ifCond, numThreads, pbKind, isCancellable)); 174 175 return bodyGenStatus; 176 } 177 178 /// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder. 179 static LogicalResult 180 convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder, 181 LLVM::ModuleTranslation &moduleTranslation) { 182 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; 183 // TODO: support error propagation in OpenMPIRBuilder and use it instead of 184 // relying on captured variables. 185 LogicalResult bodyGenStatus = success(); 186 187 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, 188 llvm::BasicBlock &continuationBlock) { 189 // MasterOp has only one region associated with it. 190 auto ®ion = cast<omp::MasterOp>(opInst).getRegion(); 191 convertOmpOpRegions(region, "omp.master.region", *codeGenIP.getBlock(), 192 continuationBlock, builder, moduleTranslation, 193 bodyGenStatus); 194 }; 195 196 // TODO: Perform finalization actions for variables. This has to be 197 // called for variables which have destructors/finalizers. 198 auto finiCB = [&](InsertPointTy codeGenIP) {}; 199 200 llvm::OpenMPIRBuilder::LocationDescription ompLoc( 201 builder.saveIP(), builder.getCurrentDebugLocation()); 202 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createMaster( 203 ompLoc, bodyGenCB, finiCB)); 204 return success(); 205 } 206 207 /// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder. 208 static LogicalResult 209 convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder, 210 LLVM::ModuleTranslation &moduleTranslation) { 211 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; 212 auto criticalOp = cast<omp::CriticalOp>(opInst); 213 // TODO: support error propagation in OpenMPIRBuilder and use it instead of 214 // relying on captured variables. 215 LogicalResult bodyGenStatus = success(); 216 217 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, 218 llvm::BasicBlock &continuationBlock) { 219 // CriticalOp has only one region associated with it. 220 auto ®ion = cast<omp::CriticalOp>(opInst).getRegion(); 221 convertOmpOpRegions(region, "omp.critical.region", *codeGenIP.getBlock(), 222 continuationBlock, builder, moduleTranslation, 223 bodyGenStatus); 224 }; 225 226 // TODO: Perform finalization actions for variables. This has to be 227 // called for variables which have destructors/finalizers. 228 auto finiCB = [&](InsertPointTy codeGenIP) {}; 229 230 llvm::OpenMPIRBuilder::LocationDescription ompLoc( 231 builder.saveIP(), builder.getCurrentDebugLocation()); 232 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); 233 llvm::Constant *hint = nullptr; 234 if (criticalOp.hint().hasValue()) { 235 hint = 236 llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 237 static_cast<int>(criticalOp.hint().getValue())); 238 } else { 239 hint = llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), 0); 240 } 241 builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createCritical( 242 ompLoc, bodyGenCB, finiCB, criticalOp.name().getValueOr(""), hint)); 243 return success(); 244 } 245 246 /// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder. 247 static LogicalResult 248 convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder, 249 LLVM::ModuleTranslation &moduleTranslation) { 250 auto loop = cast<omp::WsLoopOp>(opInst); 251 // TODO: this should be in the op verifier instead. 252 if (loop.lowerBound().empty()) 253 return failure(); 254 255 // Static is the default. 256 omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static; 257 if (loop.schedule_val().hasValue()) 258 schedule = 259 *omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()); 260 261 // Set up the source location value for OpenMP runtime. 262 llvm::DISubprogram *subprogram = 263 builder.GetInsertBlock()->getParent()->getSubprogram(); 264 const llvm::DILocation *diLoc = 265 moduleTranslation.translateLoc(opInst.getLoc(), subprogram); 266 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(), 267 llvm::DebugLoc(diLoc)); 268 269 // Generator of the canonical loop body. 270 // TODO: support error propagation in OpenMPIRBuilder and use it instead of 271 // relying on captured variables. 272 SmallVector<llvm::CanonicalLoopInfo *> loopInfos; 273 SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints; 274 LogicalResult bodyGenStatus = success(); 275 auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) { 276 // Make sure further conversions know about the induction variable. 277 moduleTranslation.mapValue( 278 loop.getRegion().front().getArgument(loopInfos.size()), iv); 279 280 // Capture the body insertion point for use in nested loops. BodyIP of the 281 // CanonicalLoopInfo always points to the beginning of the entry block of 282 // the body. 283 bodyInsertPoints.push_back(ip); 284 285 if (loopInfos.size() != loop.getNumLoops() - 1) 286 return; 287 288 // Convert the body of the loop. 289 llvm::BasicBlock *entryBlock = ip.getBlock(); 290 llvm::BasicBlock *exitBlock = 291 entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit"); 292 convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock, 293 *exitBlock, builder, moduleTranslation, bodyGenStatus); 294 }; 295 296 // Delegate actual loop construction to the OpenMP IRBuilder. 297 // TODO: this currently assumes WsLoop is semantically similar to SCF loop, 298 // i.e. it has a positive step, uses signed integer semantics. Reconsider 299 // this code when WsLoop clearly supports more cases. 300 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 301 for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) { 302 llvm::Value *lowerBound = 303 moduleTranslation.lookupValue(loop.lowerBound()[i]); 304 llvm::Value *upperBound = 305 moduleTranslation.lookupValue(loop.upperBound()[i]); 306 llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]); 307 308 // Make sure loop trip count are emitted in the preheader of the outermost 309 // loop at the latest so that they are all available for the new collapsed 310 // loop will be created below. 311 llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc; 312 llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP; 313 if (i != 0) { 314 loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(), 315 llvm::DebugLoc(diLoc)); 316 computeIP = loopInfos.front()->getPreheaderIP(); 317 } 318 loopInfos.push_back(ompBuilder->createCanonicalLoop( 319 loc, bodyGen, lowerBound, upperBound, step, 320 /*IsSigned=*/true, loop.inclusive(), computeIP)); 321 322 if (failed(bodyGenStatus)) 323 return failure(); 324 } 325 326 // Collapse loops. Store the insertion point because LoopInfos may get 327 // invalidated. 328 llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP(); 329 llvm::CanonicalLoopInfo *loopInfo = 330 ompBuilder->collapseLoops(diLoc, loopInfos, {}); 331 332 // Find the loop configuration. 333 llvm::Type *ivType = loopInfo->getIndVar()->getType(); 334 llvm::Value *chunk = 335 loop.schedule_chunk_var() 336 ? moduleTranslation.lookupValue(loop.schedule_chunk_var()) 337 : llvm::ConstantInt::get(ivType, 1); 338 llvm::OpenMPIRBuilder::InsertPointTy allocaIP = 339 findAllocaInsertPoint(builder, moduleTranslation); 340 if (schedule == omp::ClauseScheduleKind::Static) { 341 ompBuilder->applyStaticWorkshareLoop(ompLoc.DL, loopInfo, allocaIP, 342 !loop.nowait(), chunk); 343 } else { 344 llvm::omp::OMPScheduleType schedType; 345 switch (schedule) { 346 case omp::ClauseScheduleKind::Dynamic: 347 schedType = llvm::omp::OMPScheduleType::DynamicChunked; 348 break; 349 case omp::ClauseScheduleKind::Guided: 350 schedType = llvm::omp::OMPScheduleType::GuidedChunked; 351 break; 352 case omp::ClauseScheduleKind::Auto: 353 schedType = llvm::omp::OMPScheduleType::Auto; 354 break; 355 case omp::ClauseScheduleKind::Runtime: 356 schedType = llvm::omp::OMPScheduleType::Runtime; 357 break; 358 default: 359 llvm_unreachable("Unknown schedule value"); 360 break; 361 } 362 363 ompBuilder->applyDynamicWorkshareLoop(ompLoc.DL, loopInfo, allocaIP, 364 schedType, !loop.nowait(), chunk); 365 } 366 367 // Continue building IR after the loop. Note that the LoopInfo returned by 368 // `collapseLoops` points inside the outermost loop and is intended for 369 // potential further loop transformations. Use the insertion point stored 370 // before collapsing loops instead. 371 builder.restoreIP(afterIP); 372 return success(); 373 } 374 375 namespace { 376 377 /// Implementation of the dialect interface that converts operations belonging 378 /// to the OpenMP dialect to LLVM IR. 379 class OpenMPDialectLLVMIRTranslationInterface 380 : public LLVMTranslationDialectInterface { 381 public: 382 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 383 384 /// Translates the given operation to LLVM IR using the provided IR builder 385 /// and saving the state in `moduleTranslation`. 386 LogicalResult 387 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 388 LLVM::ModuleTranslation &moduleTranslation) const final; 389 }; 390 391 } // end namespace 392 393 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR 394 /// (including OpenMP runtime calls). 395 LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( 396 Operation *op, llvm::IRBuilderBase &builder, 397 LLVM::ModuleTranslation &moduleTranslation) const { 398 399 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); 400 401 return llvm::TypeSwitch<Operation *, LogicalResult>(op) 402 .Case([&](omp::BarrierOp) { 403 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier); 404 return success(); 405 }) 406 .Case([&](omp::TaskwaitOp) { 407 ompBuilder->createTaskwait(builder.saveIP()); 408 return success(); 409 }) 410 .Case([&](omp::TaskyieldOp) { 411 ompBuilder->createTaskyield(builder.saveIP()); 412 return success(); 413 }) 414 .Case([&](omp::FlushOp) { 415 // No support in Openmp runtime function (__kmpc_flush) to accept 416 // the argument list. 417 // OpenMP standard states the following: 418 // "An implementation may implement a flush with a list by ignoring 419 // the list, and treating it the same as a flush without a list." 420 // 421 // The argument list is discarded so that, flush with a list is treated 422 // same as a flush without a list. 423 ompBuilder->createFlush(builder.saveIP()); 424 return success(); 425 }) 426 .Case([&](omp::ParallelOp) { 427 return convertOmpParallel(*op, builder, moduleTranslation); 428 }) 429 .Case([&](omp::MasterOp) { 430 return convertOmpMaster(*op, builder, moduleTranslation); 431 }) 432 .Case([&](omp::CriticalOp) { 433 return convertOmpCritical(*op, builder, moduleTranslation); 434 }) 435 .Case([&](omp::WsLoopOp) { 436 return convertOmpWsLoop(*op, builder, moduleTranslation); 437 }) 438 .Case<omp::YieldOp, omp::TerminatorOp>([](auto op) { 439 // `yield` and `terminator` can be just omitted. The block structure was 440 // created in the function that handles their parent operation. 441 assert(op->getNumOperands() == 0 && 442 "unexpected OpenMP terminator with operands"); 443 return success(); 444 }) 445 .Default([&](Operation *inst) { 446 return inst->emitError("unsupported OpenMP operation: ") 447 << inst->getName(); 448 }); 449 } 450 451 void mlir::registerOpenMPDialectTranslation(DialectRegistry ®istry) { 452 registry.insert<omp::OpenMPDialect>(); 453 registry.addDialectInterface<omp::OpenMPDialect, 454 OpenMPDialectLLVMIRTranslationInterface>(); 455 } 456 457 void mlir::registerOpenMPDialectTranslation(MLIRContext &context) { 458 DialectRegistry registry; 459 registerOpenMPDialectTranslation(registry); 460 context.appendDialectRegistry(registry); 461 } 462