1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR OpenACC dialect and LLVM 10 // IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" 15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/OpenACC/OpenACC.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 22 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Frontend/OpenMP/OMPConstants.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder; 30 31 //===----------------------------------------------------------------------===// 32 // Utility functions 33 //===----------------------------------------------------------------------===// 34 35 /// 0 = alloc/create 36 static constexpr uint64_t kCreateFlag = 0; 37 /// 1 = to/device/copyin 38 static constexpr uint64_t kDeviceCopyinFlag = 1; 39 /// 2 = from/copyout 40 static constexpr uint64_t kHostCopyoutFlag = 2; 41 /// 8 = delete 42 static constexpr uint64_t kDeleteFlag = 8; 43 44 /// Default value for the device id 45 static constexpr int64_t kDefaultDevice = -1; 46 47 /// Create a constant string location from the MLIR Location information. 48 static llvm::Constant *createSourceLocStrFromLocation(Location loc, 49 OpenACCIRBuilder &builder, 50 StringRef name) { 51 if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) { 52 StringRef fileName = fileLoc.getFilename(); 53 unsigned lineNo = fileLoc.getLine(); 54 unsigned colNo = fileLoc.getColumn(); 55 return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo); 56 } else { 57 std::string locStr; 58 llvm::raw_string_ostream locOS(locStr); 59 locOS << loc; 60 return builder.getOrCreateSrcLocStr(locOS.str()); 61 } 62 } 63 64 /// Create the location struct from the operation location information. 65 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder, 66 Operation *op) { 67 auto loc = op->getLoc(); 68 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); 69 StringRef funcName = funcOp ? funcOp.getName() : "unknown"; 70 llvm::Constant *locStr = 71 createSourceLocStrFromLocation(loc, builder, funcName); 72 return builder.getOrCreateIdent(locStr); 73 } 74 75 /// Create a constant string representing the mapping information extracted from 76 /// the MLIR location information. 77 static llvm::Constant *createMappingInformation(Location loc, 78 OpenACCIRBuilder &builder) { 79 if (auto nameLoc = loc.dyn_cast<NameLoc>()) { 80 StringRef name = nameLoc.getName(); 81 return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name); 82 } else { 83 return createSourceLocStrFromLocation(loc, builder, "unknown"); 84 } 85 } 86 87 /// Return the runtime function used to lower the given operation. 88 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder, 89 Operation *op) { 90 return llvm::TypeSwitch<Operation *, llvm::Function *>(op) 91 .Case([&](acc::EnterDataOp) { 92 return builder.getOrCreateRuntimeFunctionPtr( 93 llvm::omp::OMPRTL___tgt_target_data_begin_mapper); 94 }) 95 .Case([&](acc::ExitDataOp) { 96 return builder.getOrCreateRuntimeFunctionPtr( 97 llvm::omp::OMPRTL___tgt_target_data_end_mapper); 98 }) 99 .Case([&](acc::UpdateOp) { 100 return builder.getOrCreateRuntimeFunctionPtr( 101 llvm::omp::OMPRTL___tgt_target_data_update_mapper); 102 }); 103 llvm_unreachable("Unknown OpenACC operation"); 104 } 105 106 /// Computes the size of type in bytes. 107 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder, 108 llvm::Value *basePtr) { 109 llvm::LLVMContext &ctx = builder.getContext(); 110 llvm::Value *null = 111 llvm::Constant::getNullValue(basePtr->getType()->getPointerTo()); 112 llvm::Value *sizeGep = 113 builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1)); 114 llvm::Value *sizePtrToInt = 115 builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx)); 116 return sizePtrToInt; 117 } 118 119 /// Extract pointer, size and mapping information from operands 120 /// to populate the future functions arguments. 121 static LogicalResult 122 processOperands(llvm::IRBuilderBase &builder, 123 LLVM::ModuleTranslation &moduleTranslation, Operation *op, 124 ValueRange operands, unsigned totalNbOperand, 125 uint64_t operandFlag, SmallVector<uint64_t> &flags, 126 SmallVector<llvm::Constant *> &names, unsigned &index, 127 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 128 llvm::AllocaInst *argSizes) { 129 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 130 llvm::LLVMContext &ctx = builder.getContext(); 131 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 132 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 133 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 134 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 135 136 for (Value data : operands) { 137 llvm::Value *dataValue = moduleTranslation.lookupValue(data); 138 139 llvm::Value *dataPtrBase; 140 llvm::Value *dataPtr; 141 llvm::Value *dataSize; 142 143 // Handle operands that were converted to DataDescriptor. 144 if (DataDescriptor::isValid(data)) { 145 dataPtrBase = 146 builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor); 147 dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor); 148 dataSize = 149 builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor); 150 } else if (data.getType().isa<LLVM::LLVMPointerType>()) { 151 dataPtrBase = dataValue; 152 dataPtr = dataValue; 153 dataSize = getSizeInBytes(builder, dataValue); 154 } else { 155 return op->emitOpError() 156 << "Data operand must be legalized before translation." 157 << "Unsupported type: " << data.getType(); 158 } 159 160 // Store base pointer extracted from operand into the i-th position of 161 // argBase. 162 llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP( 163 arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(index)}); 164 llvm::Value *ptrBaseCast = builder.CreateBitCast( 165 ptrBaseGEP, dataPtrBase->getType()->getPointerTo()); 166 builder.CreateStore(dataPtrBase, ptrBaseCast); 167 168 // Store pointer extracted from operand into the i-th position of args. 169 llvm::Value *ptrGEP = builder.CreateInBoundsGEP( 170 arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(index)}); 171 llvm::Value *ptrCast = 172 builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo()); 173 builder.CreateStore(dataPtr, ptrCast); 174 175 // Store size extracted from operand into the i-th position of argSizes. 176 llvm::Value *sizeGEP = builder.CreateInBoundsGEP( 177 arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)}); 178 builder.CreateStore(dataSize, sizeGEP); 179 180 flags.push_back(operandFlag); 181 llvm::Constant *mapName = 182 createMappingInformation(data.getLoc(), *accBuilder); 183 names.push_back(mapName); 184 ++index; 185 } 186 return success(); 187 } 188 189 /// Process data operands from acc::EnterDataOp 190 static LogicalResult processDataOperands( 191 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, 192 acc::EnterDataOp op, SmallVector<uint64_t> &flags, 193 SmallVector<llvm::Constant *> &names, llvm::AllocaInst *argsBase, 194 llvm::AllocaInst *args, llvm::AllocaInst *argSizes) { 195 // TODO add `create_zero` and `attach` operands 196 197 unsigned index = 0; 198 199 // Create operands are handled as `alloc` call. 200 if (failed(processOperands(builder, moduleTranslation, op, 201 op.createOperands(), op.getNumDataOperands(), 202 kCreateFlag, flags, names, index, argsBase, args, 203 argSizes))) 204 return failure(); 205 206 // Copyin operands are handled as `to` call. 207 if (failed(processOperands(builder, moduleTranslation, op, 208 op.copyinOperands(), op.getNumDataOperands(), 209 kDeviceCopyinFlag, flags, names, index, argsBase, 210 args, argSizes))) 211 return failure(); 212 213 return success(); 214 } 215 216 /// Process data operands from acc::ExitDataOp 217 static LogicalResult processDataOperands( 218 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, 219 acc::ExitDataOp op, SmallVector<uint64_t> &flags, 220 SmallVector<llvm::Constant *> &names, llvm::AllocaInst *argsBase, 221 llvm::AllocaInst *args, llvm::AllocaInst *argSizes) { 222 // TODO add `detach` operands 223 224 unsigned index = 0; 225 226 // Delete operands are handled as `delete` call. 227 if (failed(processOperands(builder, moduleTranslation, op, 228 op.deleteOperands(), op.getNumDataOperands(), 229 kDeleteFlag, flags, names, index, argsBase, args, 230 argSizes))) 231 return failure(); 232 233 // Copyout operands are handled as `from` call. 234 if (failed(processOperands(builder, moduleTranslation, op, 235 op.copyoutOperands(), op.getNumDataOperands(), 236 kHostCopyoutFlag, flags, names, index, argsBase, 237 args, argSizes))) 238 return failure(); 239 240 return success(); 241 } 242 243 /// Process data operands from acc::UpdateOp 244 static LogicalResult processDataOperands( 245 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, 246 acc::UpdateOp op, SmallVector<uint64_t> &flags, 247 SmallVector<llvm::Constant *> &names, llvm::AllocaInst *argsBase, 248 llvm::AllocaInst *args, llvm::AllocaInst *argSizes) { 249 unsigned index = 0; 250 251 // Host operands are handled as `from` call. 252 if (failed(processOperands(builder, moduleTranslation, op, op.hostOperands(), 253 op.getNumDataOperands(), kHostCopyoutFlag, flags, 254 names, index, argsBase, args, argSizes))) 255 return failure(); 256 257 // Device operands are handled as `to` call. 258 if (failed(processOperands(builder, moduleTranslation, op, 259 op.deviceOperands(), op.getNumDataOperands(), 260 kDeviceCopyinFlag, flags, names, index, argsBase, 261 args, argSizes))) 262 return failure(); 263 264 return success(); 265 } 266 267 //===----------------------------------------------------------------------===// 268 // Conversion functions 269 //===----------------------------------------------------------------------===// 270 271 /// Converts an OpenACC standalone data operation into LLVM IR. 272 template <typename OpTy> 273 static LogicalResult 274 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder, 275 LLVM::ModuleTranslation &moduleTranslation) { 276 auto enclosingFuncOp = 277 op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>(); 278 llvm::Function *enclosingFunction = 279 moduleTranslation.lookupFunction(enclosingFuncOp.getName()); 280 281 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 282 283 auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op); 284 auto *mapperFunc = getAssociatedFunction(*accBuilder, op); 285 286 // Number of arguments in the enter_data operation. 287 unsigned totalNbOperand = op.getNumDataOperands(); 288 289 // TODO could be moved to OpenXXIRBuilder? 290 llvm::LLVMContext &ctx = builder.getContext(); 291 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 292 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 293 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 294 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 295 llvm::IRBuilder<>::InsertPoint allocaIP( 296 &enclosingFunction->getEntryBlock(), 297 enclosingFunction->getEntryBlock().getFirstInsertionPt()); 298 llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP(); 299 builder.restoreIP(allocaIP); 300 llvm::AllocaInst *argsBase = builder.CreateAlloca(arrI8PtrTy); 301 llvm::AllocaInst *args = builder.CreateAlloca(arrI8PtrTy); 302 llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty); 303 builder.restoreIP(currentIP); 304 305 SmallVector<uint64_t> flags; 306 SmallVector<llvm::Constant *> names; 307 308 if (failed(processDataOperands(builder, moduleTranslation, op, flags, names, 309 argsBase, args, argSizes))) 310 return failure(); 311 312 llvm::GlobalVariable *maptypes = 313 accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); 314 llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( 315 llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), 316 maptypes, /*Idx0=*/0, /*Idx1=*/0); 317 318 llvm::GlobalVariable *mapnames = 319 accBuilder->createOffloadMapnames(names, ".offload_mapnames"); 320 llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( 321 llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), 322 mapnames, /*Idx0=*/0, /*Idx1=*/0); 323 324 llvm::Value *argsBaseGEP = builder.CreateInBoundsGEP( 325 arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(0)}); 326 llvm::Value *argsGEP = builder.CreateInBoundsGEP( 327 arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(0)}); 328 llvm::Value *argSizesGEP = builder.CreateInBoundsGEP( 329 arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)}); 330 llvm::Value *nullPtr = llvm::Constant::getNullValue( 331 llvm::Type::getInt8PtrTy(ctx)->getPointerTo()); 332 333 builder.CreateCall(mapperFunc, 334 {srcLocInfo, builder.getInt64(kDefaultDevice), 335 builder.getInt32(totalNbOperand), argsBaseGEP, argsGEP, 336 argSizesGEP, maptypesArg, mapnamesArg, nullPtr}); 337 338 return success(); 339 } 340 341 namespace { 342 343 /// Implementation of the dialect interface that converts operations belonging 344 /// to the OpenACC dialect to LLVM IR. 345 class OpenACCDialectLLVMIRTranslationInterface 346 : public LLVMTranslationDialectInterface { 347 public: 348 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 349 350 /// Translates the given operation to LLVM IR using the provided IR builder 351 /// and saving the state in `moduleTranslation`. 352 LogicalResult 353 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 354 LLVM::ModuleTranslation &moduleTranslation) const final; 355 }; 356 357 } // end namespace 358 359 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR 360 /// (including OpenACC runtime calls). 361 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation( 362 Operation *op, llvm::IRBuilderBase &builder, 363 LLVM::ModuleTranslation &moduleTranslation) const { 364 365 return llvm::TypeSwitch<Operation *, LogicalResult>(op) 366 .Case([&](acc::EnterDataOp enterDataOp) { 367 return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder, 368 moduleTranslation); 369 }) 370 .Case([&](acc::ExitDataOp exitDataOp) { 371 return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder, 372 moduleTranslation); 373 }) 374 .Case([&](acc::UpdateOp updateOp) { 375 return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder, 376 moduleTranslation); 377 }) 378 .Default([&](Operation *op) { 379 return op->emitError("unsupported OpenACC operation: ") 380 << op->getName(); 381 }); 382 } 383 384 void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) { 385 registry.insert<acc::OpenACCDialect>(); 386 registry.addDialectInterface<acc::OpenACCDialect, 387 OpenACCDialectLLVMIRTranslationInterface>(); 388 } 389 390 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) { 391 DialectRegistry registry; 392 registerOpenACCDialectTranslation(registry); 393 context.appendDialectRegistry(registry); 394 } 395