1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR OpenACC dialect and LLVM 10 // IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" 15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/OpenACC/OpenACC.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 22 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Frontend/OpenMP/OMPConstants.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder; 30 31 //===----------------------------------------------------------------------===// 32 // Utility functions 33 //===----------------------------------------------------------------------===// 34 35 /// 0 = alloc/create 36 static constexpr uint64_t kCreateFlag = 0; 37 /// 1 = to/device/copyin 38 static constexpr uint64_t kDeviceCopyinFlag = 1; 39 /// 2 = from/copyout 40 static constexpr uint64_t kHostCopyoutFlag = 2; 41 /// 8 = delete 42 static constexpr uint64_t kDeleteFlag = 8; 43 44 /// Default value for the device id 45 static constexpr int64_t kDefaultDevice = -1; 46 47 /// Create a constant string location from the MLIR Location information. 48 static llvm::Constant *createSourceLocStrFromLocation(Location loc, 49 OpenACCIRBuilder &builder, 50 StringRef name) { 51 if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) { 52 StringRef fileName = fileLoc.getFilename(); 53 unsigned lineNo = fileLoc.getLine(); 54 unsigned colNo = fileLoc.getColumn(); 55 return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo); 56 } else { 57 std::string locStr; 58 llvm::raw_string_ostream locOS(locStr); 59 locOS << loc; 60 return builder.getOrCreateSrcLocStr(locOS.str()); 61 } 62 } 63 64 /// Create the location struct from the operation location information. 65 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder, 66 Operation *op) { 67 auto loc = op->getLoc(); 68 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); 69 StringRef funcName = funcOp ? funcOp.getName() : "unknown"; 70 llvm::Constant *locStr = 71 createSourceLocStrFromLocation(loc, builder, funcName); 72 return builder.getOrCreateIdent(locStr); 73 } 74 75 /// Create a constant string representing the mapping information extracted from 76 /// the MLIR location information. 77 static llvm::Constant *createMappingInformation(Location loc, 78 OpenACCIRBuilder &builder) { 79 if (auto nameLoc = loc.dyn_cast<NameLoc>()) { 80 StringRef name = nameLoc.getName(); 81 return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name); 82 } else { 83 return createSourceLocStrFromLocation(loc, builder, "unknown"); 84 } 85 } 86 87 /// Return the runtime function used to lower the given operation. 88 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder, 89 Operation *op) { 90 return llvm::TypeSwitch<Operation *, llvm::Function *>(op) 91 .Case([&](acc::EnterDataOp) { 92 return builder.getOrCreateRuntimeFunctionPtr( 93 llvm::omp::OMPRTL___tgt_target_data_begin_mapper); 94 }) 95 .Case([&](acc::ExitDataOp) { 96 return builder.getOrCreateRuntimeFunctionPtr( 97 llvm::omp::OMPRTL___tgt_target_data_end_mapper); 98 }) 99 .Case([&](acc::UpdateOp) { 100 return builder.getOrCreateRuntimeFunctionPtr( 101 llvm::omp::OMPRTL___tgt_target_data_update_mapper); 102 }); 103 llvm_unreachable("Unknown OpenACC operation"); 104 } 105 106 /// Computes the size of type in bytes. 107 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder, 108 llvm::Value *basePtr) { 109 llvm::LLVMContext &ctx = builder.getContext(); 110 llvm::Value *null = 111 llvm::Constant::getNullValue(basePtr->getType()->getPointerTo()); 112 llvm::Value *sizeGep = 113 builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1)); 114 llvm::Value *sizePtrToInt = 115 builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx)); 116 return sizePtrToInt; 117 } 118 119 /// Extract pointer, size and mapping information from operands 120 /// to populate the future functions arguments. 121 static LogicalResult 122 processOperands(llvm::IRBuilderBase &builder, 123 LLVM::ModuleTranslation &moduleTranslation, Operation *op, 124 ValueRange operands, unsigned totalNbOperand, 125 uint64_t operandFlag, SmallVector<uint64_t> &flags, 126 SmallVector<llvm::Constant *> &names, unsigned &index, 127 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 128 llvm::AllocaInst *argSizes) { 129 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 130 llvm::LLVMContext &ctx = builder.getContext(); 131 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 132 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 133 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 134 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 135 136 for (Value data : operands) { 137 llvm::Value *dataValue = moduleTranslation.lookupValue(data); 138 139 llvm::Value *dataPtrBase; 140 llvm::Value *dataPtr; 141 llvm::Value *dataSize; 142 143 // Handle operands that were converted to DataDescriptor. 144 if (DataDescriptor::isValid(data)) { 145 dataPtrBase = 146 builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor); 147 dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor); 148 dataSize = 149 builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor); 150 } else if (data.getType().isa<LLVM::LLVMPointerType>()) { 151 dataPtrBase = dataValue; 152 dataPtr = dataValue; 153 dataSize = getSizeInBytes(builder, dataValue); 154 } else { 155 return op->emitOpError() 156 << "Data operand must be legalized before translation." 157 << "Unsupported type: " << data.getType(); 158 } 159 160 // Store base pointer extracted from operand into the i-th position of 161 // argBase. 162 llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP( 163 arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(index)}); 164 llvm::Value *ptrBaseCast = builder.CreateBitCast( 165 ptrBaseGEP, dataPtrBase->getType()->getPointerTo()); 166 builder.CreateStore(dataPtrBase, ptrBaseCast); 167 168 // Store pointer extracted from operand into the i-th position of args. 169 llvm::Value *ptrGEP = builder.CreateInBoundsGEP( 170 arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(index)}); 171 llvm::Value *ptrCast = 172 builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo()); 173 builder.CreateStore(dataPtr, ptrCast); 174 175 // Store size extracted from operand into the i-th position of argSizes. 176 llvm::Value *sizeGEP = builder.CreateInBoundsGEP( 177 arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)}); 178 builder.CreateStore(dataSize, sizeGEP); 179 180 flags.push_back(operandFlag); 181 llvm::Constant *mapName = 182 createMappingInformation(data.getLoc(), *accBuilder); 183 names.push_back(mapName); 184 ++index; 185 } 186 return success(); 187 } 188 189 /// Process data operands from acc::EnterDataOp 190 static LogicalResult 191 processDataOperands(llvm::IRBuilderBase &builder, 192 LLVM::ModuleTranslation &moduleTranslation, 193 acc::EnterDataOp op, SmallVector<uint64_t> &flags, 194 SmallVector<llvm::Constant *> &names, unsigned &index, 195 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 196 llvm::AllocaInst *argSizes) { 197 // TODO add `create_zero` and `attach` operands 198 199 // Create operands are handled as `alloc` call. 200 if (failed(processOperands(builder, moduleTranslation, op, 201 op.createOperands(), op.getNumDataOperands(), 202 kCreateFlag, flags, names, index, argsBase, args, 203 argSizes))) 204 return failure(); 205 206 // Copyin operands are handled as `to` call. 207 if (failed(processOperands(builder, moduleTranslation, op, 208 op.copyinOperands(), op.getNumDataOperands(), 209 kDeviceCopyinFlag, flags, names, index, argsBase, 210 args, argSizes))) 211 return failure(); 212 213 return success(); 214 } 215 216 /// Process data operands from acc::ExitDataOp 217 static LogicalResult 218 processDataOperands(llvm::IRBuilderBase &builder, 219 LLVM::ModuleTranslation &moduleTranslation, 220 acc::ExitDataOp op, SmallVector<uint64_t> &flags, 221 SmallVector<llvm::Constant *> &names, unsigned &index, 222 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 223 llvm::AllocaInst *argSizes) { 224 // TODO add `detach` operands 225 226 // Delete operands are handled as `delete` call. 227 if (failed(processOperands(builder, moduleTranslation, op, 228 op.deleteOperands(), op.getNumDataOperands(), 229 kDeleteFlag, flags, names, index, argsBase, args, 230 argSizes))) 231 return failure(); 232 233 // Copyout operands are handled as `from` call. 234 if (failed(processOperands(builder, moduleTranslation, op, 235 op.copyoutOperands(), op.getNumDataOperands(), 236 kHostCopyoutFlag, flags, names, index, argsBase, 237 args, argSizes))) 238 return failure(); 239 240 return success(); 241 } 242 243 /// Process data operands from acc::UpdateOp 244 static LogicalResult 245 processDataOperands(llvm::IRBuilderBase &builder, 246 LLVM::ModuleTranslation &moduleTranslation, 247 acc::UpdateOp op, SmallVector<uint64_t> &flags, 248 SmallVector<llvm::Constant *> &names, unsigned &index, 249 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 250 llvm::AllocaInst *argSizes) { 251 252 // Host operands are handled as `from` call. 253 if (failed(processOperands(builder, moduleTranslation, op, op.hostOperands(), 254 op.getNumDataOperands(), kHostCopyoutFlag, flags, 255 names, index, argsBase, args, argSizes))) 256 return failure(); 257 258 // Device operands are handled as `to` call. 259 if (failed(processOperands(builder, moduleTranslation, op, 260 op.deviceOperands(), op.getNumDataOperands(), 261 kDeviceCopyinFlag, flags, names, index, argsBase, 262 args, argSizes))) 263 return failure(); 264 265 return success(); 266 } 267 268 //===----------------------------------------------------------------------===// 269 // Conversion functions 270 //===----------------------------------------------------------------------===// 271 272 /// Converts an OpenACC standalone data operation into LLVM IR. 273 template <typename OpTy> 274 static LogicalResult 275 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder, 276 LLVM::ModuleTranslation &moduleTranslation) { 277 auto enclosingFuncOp = 278 op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>(); 279 llvm::Function *enclosingFunction = 280 moduleTranslation.lookupFunction(enclosingFuncOp.getName()); 281 282 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 283 284 auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op); 285 auto *mapperFunc = getAssociatedFunction(*accBuilder, op); 286 287 // Number of arguments in the enter_data operation. 288 unsigned totalNbOperand = op.getNumDataOperands(); 289 290 // TODO could be moved to OpenXXIRBuilder? 291 llvm::LLVMContext &ctx = builder.getContext(); 292 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 293 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 294 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 295 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 296 llvm::IRBuilder<>::InsertPoint allocaIP( 297 &enclosingFunction->getEntryBlock(), 298 enclosingFunction->getEntryBlock().getFirstInsertionPt()); 299 llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP(); 300 builder.restoreIP(allocaIP); 301 llvm::AllocaInst *argsBase = builder.CreateAlloca(arrI8PtrTy); 302 llvm::AllocaInst *args = builder.CreateAlloca(arrI8PtrTy); 303 llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty); 304 builder.restoreIP(currentIP); 305 306 SmallVector<uint64_t> flags; 307 SmallVector<llvm::Constant *> names; 308 unsigned index = 0; 309 310 if (failed(processDataOperands(builder, moduleTranslation, op, flags, names, 311 index, argsBase, args, argSizes))) 312 return failure(); 313 314 llvm::GlobalVariable *maptypes = 315 accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); 316 llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( 317 llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), 318 maptypes, /*Idx0=*/0, /*Idx1=*/0); 319 320 llvm::GlobalVariable *mapnames = 321 accBuilder->createOffloadMapnames(names, ".offload_mapnames"); 322 llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( 323 llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), 324 mapnames, /*Idx0=*/0, /*Idx1=*/0); 325 326 llvm::Value *argsBaseGEP = builder.CreateInBoundsGEP( 327 arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(0)}); 328 llvm::Value *argsGEP = builder.CreateInBoundsGEP( 329 arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(0)}); 330 llvm::Value *argSizesGEP = builder.CreateInBoundsGEP( 331 arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)}); 332 llvm::Value *nullPtr = llvm::Constant::getNullValue( 333 llvm::Type::getInt8PtrTy(ctx)->getPointerTo()); 334 335 builder.CreateCall(mapperFunc, 336 {srcLocInfo, builder.getInt64(kDefaultDevice), 337 builder.getInt32(totalNbOperand), argsBaseGEP, argsGEP, 338 argSizesGEP, maptypesArg, mapnamesArg, nullPtr}); 339 340 return success(); 341 } 342 343 namespace { 344 345 /// Implementation of the dialect interface that converts operations belonging 346 /// to the OpenACC dialect to LLVM IR. 347 class OpenACCDialectLLVMIRTranslationInterface 348 : public LLVMTranslationDialectInterface { 349 public: 350 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 351 352 /// Translates the given operation to LLVM IR using the provided IR builder 353 /// and saving the state in `moduleTranslation`. 354 LogicalResult 355 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 356 LLVM::ModuleTranslation &moduleTranslation) const final; 357 }; 358 359 } // end namespace 360 361 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR 362 /// (including OpenACC runtime calls). 363 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation( 364 Operation *op, llvm::IRBuilderBase &builder, 365 LLVM::ModuleTranslation &moduleTranslation) const { 366 367 return llvm::TypeSwitch<Operation *, LogicalResult>(op) 368 .Case([&](acc::EnterDataOp enterDataOp) { 369 return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder, 370 moduleTranslation); 371 }) 372 .Case([&](acc::ExitDataOp exitDataOp) { 373 return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder, 374 moduleTranslation); 375 }) 376 .Case([&](acc::UpdateOp updateOp) { 377 return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder, 378 moduleTranslation); 379 }) 380 .Default([&](Operation *op) { 381 return op->emitError("unsupported OpenACC operation: ") 382 << op->getName(); 383 }); 384 } 385 386 void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) { 387 registry.insert<acc::OpenACCDialect>(); 388 registry.addDialectInterface<acc::OpenACCDialect, 389 OpenACCDialectLLVMIRTranslationInterface>(); 390 } 391 392 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) { 393 DialectRegistry registry; 394 registerOpenACCDialectTranslation(registry); 395 context.appendDialectRegistry(registry); 396 } 397