1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR OpenACC dialect and LLVM 10 // IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" 15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/OpenACC/OpenACC.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 22 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Frontend/OpenMP/OMPConstants.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder; 30 31 //===----------------------------------------------------------------------===// 32 // Utility functions 33 //===----------------------------------------------------------------------===// 34 35 /// Flag values are extracted from openmp/libomptarget/include/omptarget.h and 36 /// mapped to corresponding OpenACC flags. 37 static constexpr uint64_t kCreateFlag = 0x000; 38 static constexpr uint64_t kDeviceCopyinFlag = 0x001; 39 static constexpr uint64_t kHostCopyoutFlag = 0x002; 40 static constexpr uint64_t kCopyFlag = kDeviceCopyinFlag | kHostCopyoutFlag; 41 static constexpr uint64_t kPresentFlag = 0x1000; 42 static constexpr uint64_t kDeleteFlag = 0x008; 43 44 /// Default value for the device id 45 static constexpr int64_t kDefaultDevice = -1; 46 47 /// Create a constant string location from the MLIR Location information. 48 static llvm::Constant *createSourceLocStrFromLocation(Location loc, 49 OpenACCIRBuilder &builder, 50 StringRef name) { 51 if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) { 52 StringRef fileName = fileLoc.getFilename(); 53 unsigned lineNo = fileLoc.getLine(); 54 unsigned colNo = fileLoc.getColumn(); 55 return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo); 56 } else { 57 std::string locStr; 58 llvm::raw_string_ostream locOS(locStr); 59 locOS << loc; 60 return builder.getOrCreateSrcLocStr(locOS.str()); 61 } 62 } 63 64 /// Create the location struct from the operation location information. 65 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder, 66 Operation *op) { 67 auto loc = op->getLoc(); 68 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); 69 StringRef funcName = funcOp ? funcOp.getName() : "unknown"; 70 llvm::Constant *locStr = 71 createSourceLocStrFromLocation(loc, builder, funcName); 72 return builder.getOrCreateIdent(locStr); 73 } 74 75 /// Create a constant string representing the mapping information extracted from 76 /// the MLIR location information. 77 static llvm::Constant *createMappingInformation(Location loc, 78 OpenACCIRBuilder &builder) { 79 if (auto nameLoc = loc.dyn_cast<NameLoc>()) { 80 StringRef name = nameLoc.getName(); 81 return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name); 82 } else { 83 return createSourceLocStrFromLocation(loc, builder, "unknown"); 84 } 85 } 86 87 /// Return the runtime function used to lower the given operation. 88 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder, 89 Operation *op) { 90 return llvm::TypeSwitch<Operation *, llvm::Function *>(op) 91 .Case([&](acc::EnterDataOp) { 92 return builder.getOrCreateRuntimeFunctionPtr( 93 llvm::omp::OMPRTL___tgt_target_data_begin_mapper); 94 }) 95 .Case([&](acc::ExitDataOp) { 96 return builder.getOrCreateRuntimeFunctionPtr( 97 llvm::omp::OMPRTL___tgt_target_data_end_mapper); 98 }) 99 .Case([&](acc::UpdateOp) { 100 return builder.getOrCreateRuntimeFunctionPtr( 101 llvm::omp::OMPRTL___tgt_target_data_update_mapper); 102 }); 103 llvm_unreachable("Unknown OpenACC operation"); 104 } 105 106 /// Computes the size of type in bytes. 107 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder, 108 llvm::Value *basePtr) { 109 llvm::LLVMContext &ctx = builder.getContext(); 110 llvm::Value *null = 111 llvm::Constant::getNullValue(basePtr->getType()->getPointerTo()); 112 llvm::Value *sizeGep = 113 builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1)); 114 llvm::Value *sizePtrToInt = 115 builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx)); 116 return sizePtrToInt; 117 } 118 119 /// Extract pointer, size and mapping information from operands 120 /// to populate the future functions arguments. 121 static LogicalResult 122 processOperands(llvm::IRBuilderBase &builder, 123 LLVM::ModuleTranslation &moduleTranslation, Operation *op, 124 ValueRange operands, unsigned totalNbOperand, 125 uint64_t operandFlag, SmallVector<uint64_t> &flags, 126 SmallVectorImpl<llvm::Constant *> &names, unsigned &index, 127 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { 128 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 129 llvm::LLVMContext &ctx = builder.getContext(); 130 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 131 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 132 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 133 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 134 135 for (Value data : operands) { 136 llvm::Value *dataValue = moduleTranslation.lookupValue(data); 137 138 llvm::Value *dataPtrBase; 139 llvm::Value *dataPtr; 140 llvm::Value *dataSize; 141 142 // Handle operands that were converted to DataDescriptor. 143 if (DataDescriptor::isValid(data)) { 144 dataPtrBase = 145 builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor); 146 dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor); 147 dataSize = 148 builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor); 149 } else if (data.getType().isa<LLVM::LLVMPointerType>()) { 150 dataPtrBase = dataValue; 151 dataPtr = dataValue; 152 dataSize = getSizeInBytes(builder, dataValue); 153 } else { 154 return op->emitOpError() 155 << "Data operand must be legalized before translation." 156 << "Unsupported type: " << data.getType(); 157 } 158 159 // Store base pointer extracted from operand into the i-th position of 160 // argBase. 161 llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP( 162 arrI8PtrTy, mapperAllocas.ArgsBase, 163 {builder.getInt32(0), builder.getInt32(index)}); 164 llvm::Value *ptrBaseCast = builder.CreateBitCast( 165 ptrBaseGEP, dataPtrBase->getType()->getPointerTo()); 166 builder.CreateStore(dataPtrBase, ptrBaseCast); 167 168 // Store pointer extracted from operand into the i-th position of args. 169 llvm::Value *ptrGEP = builder.CreateInBoundsGEP( 170 arrI8PtrTy, mapperAllocas.Args, 171 {builder.getInt32(0), builder.getInt32(index)}); 172 llvm::Value *ptrCast = 173 builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo()); 174 builder.CreateStore(dataPtr, ptrCast); 175 176 // Store size extracted from operand into the i-th position of argSizes. 177 llvm::Value *sizeGEP = builder.CreateInBoundsGEP( 178 arrI64Ty, mapperAllocas.ArgSizes, 179 {builder.getInt32(0), builder.getInt32(index)}); 180 builder.CreateStore(dataSize, sizeGEP); 181 182 flags.push_back(operandFlag); 183 llvm::Constant *mapName = 184 createMappingInformation(data.getLoc(), *accBuilder); 185 names.push_back(mapName); 186 ++index; 187 } 188 return success(); 189 } 190 191 /// Process data operands from acc::EnterDataOp 192 static LogicalResult 193 processDataOperands(llvm::IRBuilderBase &builder, 194 LLVM::ModuleTranslation &moduleTranslation, 195 acc::EnterDataOp op, SmallVector<uint64_t> &flags, 196 SmallVectorImpl<llvm::Constant *> &names, 197 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { 198 // TODO add `create_zero` and `attach` operands 199 200 unsigned index = 0; 201 202 // Create operands are handled as `alloc` call. 203 if (failed(processOperands(builder, moduleTranslation, op, 204 op.createOperands(), op.getNumDataOperands(), 205 kCreateFlag, flags, names, index, mapperAllocas))) 206 return failure(); 207 208 // Copyin operands are handled as `to` call. 209 if (failed(processOperands(builder, moduleTranslation, op, 210 op.copyinOperands(), op.getNumDataOperands(), 211 kDeviceCopyinFlag, flags, names, index, 212 mapperAllocas))) 213 return failure(); 214 215 return success(); 216 } 217 218 /// Process data operands from acc::ExitDataOp 219 static LogicalResult 220 processDataOperands(llvm::IRBuilderBase &builder, 221 LLVM::ModuleTranslation &moduleTranslation, 222 acc::ExitDataOp op, SmallVector<uint64_t> &flags, 223 SmallVectorImpl<llvm::Constant *> &names, 224 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { 225 // TODO add `detach` operands 226 227 unsigned index = 0; 228 229 // Delete operands are handled as `delete` call. 230 if (failed(processOperands(builder, moduleTranslation, op, 231 op.deleteOperands(), op.getNumDataOperands(), 232 kDeleteFlag, flags, names, index, mapperAllocas))) 233 return failure(); 234 235 // Copyout operands are handled as `from` call. 236 if (failed(processOperands(builder, moduleTranslation, op, 237 op.copyoutOperands(), op.getNumDataOperands(), 238 kHostCopyoutFlag, flags, names, index, 239 mapperAllocas))) 240 return failure(); 241 242 return success(); 243 } 244 245 /// Process data operands from acc::UpdateOp 246 static LogicalResult 247 processDataOperands(llvm::IRBuilderBase &builder, 248 LLVM::ModuleTranslation &moduleTranslation, 249 acc::UpdateOp op, SmallVector<uint64_t> &flags, 250 SmallVectorImpl<llvm::Constant *> &names, 251 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) { 252 unsigned index = 0; 253 254 // Host operands are handled as `from` call. 255 if (failed(processOperands(builder, moduleTranslation, op, op.hostOperands(), 256 op.getNumDataOperands(), kHostCopyoutFlag, flags, 257 names, index, mapperAllocas))) 258 return failure(); 259 260 // Device operands are handled as `to` call. 261 if (failed(processOperands(builder, moduleTranslation, op, 262 op.deviceOperands(), op.getNumDataOperands(), 263 kDeviceCopyinFlag, flags, names, index, 264 mapperAllocas))) 265 return failure(); 266 267 return success(); 268 } 269 270 //===----------------------------------------------------------------------===// 271 // Conversion functions 272 //===----------------------------------------------------------------------===// 273 274 /// Converts an OpenACC data operation into LLVM IR. 275 static LogicalResult convertDataOp(acc::DataOp &op, 276 llvm::IRBuilderBase &builder, 277 LLVM::ModuleTranslation &moduleTranslation) { 278 llvm::LLVMContext &ctx = builder.getContext(); 279 auto enclosingFuncOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>(); 280 llvm::Function *enclosingFunction = 281 moduleTranslation.lookupFunction(enclosingFuncOp.getName()); 282 283 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 284 285 llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op); 286 287 llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr( 288 llvm::omp::OMPRTL___tgt_target_data_begin_mapper); 289 290 llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr( 291 llvm::omp::OMPRTL___tgt_target_data_end_mapper); 292 293 // Number of arguments in the data operation. 294 unsigned totalNbOperand = op.getNumDataOperands(); 295 296 struct OpenACCIRBuilder::MapperAllocas mapperAllocas; 297 OpenACCIRBuilder::InsertPointTy allocaIP( 298 &enclosingFunction->getEntryBlock(), 299 enclosingFunction->getEntryBlock().getFirstInsertionPt()); 300 accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand, 301 mapperAllocas); 302 303 SmallVector<uint64_t> flags; 304 SmallVector<llvm::Constant *> names; 305 unsigned index = 0; 306 307 // TODO handle no_create, deviceptr and attach operands. 308 309 if (failed(processOperands(builder, moduleTranslation, op, op.copyOperands(), 310 totalNbOperand, kCopyFlag, flags, names, index, 311 mapperAllocas))) 312 return failure(); 313 314 if (failed(processOperands( 315 builder, moduleTranslation, op, op.copyinOperands(), totalNbOperand, 316 kDeviceCopyinFlag, flags, names, index, mapperAllocas))) 317 return failure(); 318 319 // TODO copyin readonly currenlty handled as copyin. Update when extension 320 // available. 321 if (failed(processOperands(builder, moduleTranslation, op, 322 op.copyinReadonlyOperands(), totalNbOperand, 323 kDeviceCopyinFlag, flags, names, index, 324 mapperAllocas))) 325 return failure(); 326 327 if (failed(processOperands( 328 builder, moduleTranslation, op, op.copyoutOperands(), totalNbOperand, 329 kHostCopyoutFlag, flags, names, index, mapperAllocas))) 330 return failure(); 331 332 // TODO copyout zero currenlty handled as copyout. Update when extension 333 // available. 334 if (failed(processOperands(builder, moduleTranslation, op, 335 op.copyoutZeroOperands(), totalNbOperand, 336 kHostCopyoutFlag, flags, names, index, 337 mapperAllocas))) 338 return failure(); 339 340 if (failed(processOperands(builder, moduleTranslation, op, 341 op.createOperands(), totalNbOperand, kCreateFlag, 342 flags, names, index, mapperAllocas))) 343 return failure(); 344 345 // TODO create zero currenlty handled as create. Update when extension 346 // available. 347 if (failed(processOperands(builder, moduleTranslation, op, 348 op.createZeroOperands(), totalNbOperand, 349 kCreateFlag, flags, names, index, mapperAllocas))) 350 return failure(); 351 352 if (failed(processOperands(builder, moduleTranslation, op, 353 op.presentOperands(), totalNbOperand, kPresentFlag, 354 flags, names, index, mapperAllocas))) 355 return failure(); 356 357 llvm::GlobalVariable *maptypes = 358 accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); 359 llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( 360 llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), 361 maptypes, /*Idx0=*/0, /*Idx1=*/0); 362 363 llvm::GlobalVariable *mapnames = 364 accBuilder->createOffloadMapnames(names, ".offload_mapnames"); 365 llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( 366 llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), 367 mapnames, /*Idx0=*/0, /*Idx1=*/0); 368 369 // Create call to start the data region. 370 accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo, 371 maptypesArg, mapnamesArg, mapperAllocas, 372 kDefaultDevice, totalNbOperand); 373 374 // Convert the region. 375 llvm::BasicBlock *entryBlock = nullptr; 376 377 for (Block &bb : op.region()) { 378 llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create( 379 ctx, "acc.data", builder.GetInsertBlock()->getParent()); 380 if (entryBlock == nullptr) 381 entryBlock = llvmBB; 382 moduleTranslation.mapBlock(&bb, llvmBB); 383 } 384 385 auto afterDataRegion = builder.saveIP(); 386 387 llvm::BranchInst *sourceTerminator = builder.CreateBr(entryBlock); 388 389 builder.restoreIP(afterDataRegion); 390 llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create( 391 ctx, "acc.end_data", builder.GetInsertBlock()->getParent()); 392 393 SetVector<Block *> blocks = 394 LLVM::detail::getTopologicallySortedBlocks(op.region()); 395 for (Block *bb : blocks) { 396 llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb); 397 if (bb->isEntryBlock()) { 398 assert(sourceTerminator->getNumSuccessors() == 1 && 399 "provided entry block has multiple successors"); 400 sourceTerminator->setSuccessor(0, llvmBB); 401 } 402 403 if (failed( 404 moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) { 405 return failure(); 406 } 407 408 if (isa<acc::TerminatorOp, acc::YieldOp>(bb->getTerminator())) 409 builder.CreateBr(endDataBlock); 410 } 411 412 // Create call to end the data region. 413 builder.SetInsertPoint(endDataBlock); 414 accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo, 415 maptypesArg, mapnamesArg, mapperAllocas, 416 kDefaultDevice, totalNbOperand); 417 418 return success(); 419 } 420 421 /// Converts an OpenACC standalone data operation into LLVM IR. 422 template <typename OpTy> 423 static LogicalResult 424 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder, 425 LLVM::ModuleTranslation &moduleTranslation) { 426 auto enclosingFuncOp = 427 op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>(); 428 llvm::Function *enclosingFunction = 429 moduleTranslation.lookupFunction(enclosingFuncOp.getName()); 430 431 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 432 433 auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op); 434 auto *mapperFunc = getAssociatedFunction(*accBuilder, op); 435 436 // Number of arguments in the enter_data operation. 437 unsigned totalNbOperand = op.getNumDataOperands(); 438 439 llvm::LLVMContext &ctx = builder.getContext(); 440 441 struct OpenACCIRBuilder::MapperAllocas mapperAllocas; 442 OpenACCIRBuilder::InsertPointTy allocaIP( 443 &enclosingFunction->getEntryBlock(), 444 enclosingFunction->getEntryBlock().getFirstInsertionPt()); 445 accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand, 446 mapperAllocas); 447 448 SmallVector<uint64_t> flags; 449 SmallVector<llvm::Constant *> names; 450 451 if (failed(processDataOperands(builder, moduleTranslation, op, flags, names, 452 mapperAllocas))) 453 return failure(); 454 455 llvm::GlobalVariable *maptypes = 456 accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); 457 llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( 458 llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), 459 maptypes, /*Idx0=*/0, /*Idx1=*/0); 460 461 llvm::GlobalVariable *mapnames = 462 accBuilder->createOffloadMapnames(names, ".offload_mapnames"); 463 llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( 464 llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), 465 mapnames, /*Idx0=*/0, /*Idx1=*/0); 466 467 accBuilder->emitMapperCall(builder.saveIP(), mapperFunc, srcLocInfo, 468 maptypesArg, mapnamesArg, mapperAllocas, 469 kDefaultDevice, totalNbOperand); 470 471 return success(); 472 } 473 474 namespace { 475 476 /// Implementation of the dialect interface that converts operations belonging 477 /// to the OpenACC dialect to LLVM IR. 478 class OpenACCDialectLLVMIRTranslationInterface 479 : public LLVMTranslationDialectInterface { 480 public: 481 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 482 483 /// Translates the given operation to LLVM IR using the provided IR builder 484 /// and saving the state in `moduleTranslation`. 485 LogicalResult 486 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 487 LLVM::ModuleTranslation &moduleTranslation) const final; 488 }; 489 490 } // end namespace 491 492 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR 493 /// (including OpenACC runtime calls). 494 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation( 495 Operation *op, llvm::IRBuilderBase &builder, 496 LLVM::ModuleTranslation &moduleTranslation) const { 497 498 return llvm::TypeSwitch<Operation *, LogicalResult>(op) 499 .Case([&](acc::DataOp dataOp) { 500 return convertDataOp(dataOp, builder, moduleTranslation); 501 }) 502 .Case([&](acc::EnterDataOp enterDataOp) { 503 return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder, 504 moduleTranslation); 505 }) 506 .Case([&](acc::ExitDataOp exitDataOp) { 507 return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder, 508 moduleTranslation); 509 }) 510 .Case([&](acc::UpdateOp updateOp) { 511 return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder, 512 moduleTranslation); 513 }) 514 .Case<acc::TerminatorOp, acc::YieldOp>([](auto op) { 515 // `yield` and `terminator` can be just omitted. The block structure was 516 // created in the function that handles their parent operation. 517 assert(op->getNumOperands() == 0 && 518 "unexpected OpenACC terminator with operands"); 519 return success(); 520 }) 521 .Default([&](Operation *op) { 522 return op->emitError("unsupported OpenACC operation: ") 523 << op->getName(); 524 }); 525 } 526 527 void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) { 528 registry.insert<acc::OpenACCDialect>(); 529 registry.addDialectInterface<acc::OpenACCDialect, 530 OpenACCDialectLLVMIRTranslationInterface>(); 531 } 532 533 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) { 534 DialectRegistry registry; 535 registerOpenACCDialectTranslation(registry); 536 context.appendDialectRegistry(registry); 537 } 538