1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR OpenACC dialect and LLVM 10 // IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" 15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/OpenACC/OpenACC.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 22 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Frontend/OpenMP/OMPConstants.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder; 30 31 //===----------------------------------------------------------------------===// 32 // Utility functions 33 //===----------------------------------------------------------------------===// 34 35 /// 0 = alloc/create 36 static constexpr uint64_t createFlag = 0; 37 /// 1 = to/copyin 38 static constexpr uint64_t copyinFlag = 1; 39 /// 2 = from/copyout 40 static constexpr uint64_t copyoutFlag = 2; 41 /// 8 = delete 42 static constexpr uint64_t deleteFlag = 8; 43 44 /// Default value for the device id 45 static constexpr int64_t defaultDevice = -1; 46 47 /// Create a constant string location from the MLIR Location information. 48 static llvm::Constant *createSourceLocStrFromLocation(Location loc, 49 OpenACCIRBuilder &builder, 50 StringRef name) { 51 if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) { 52 StringRef fileName = fileLoc.getFilename(); 53 unsigned lineNo = fileLoc.getLine(); 54 unsigned colNo = fileLoc.getColumn(); 55 return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo); 56 } else { 57 std::string locStr; 58 llvm::raw_string_ostream locOS(locStr); 59 locOS << loc; 60 return builder.getOrCreateSrcLocStr(locOS.str()); 61 } 62 } 63 64 /// Create the location struct from the operation location information. 65 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder, 66 Operation *op) { 67 auto loc = op->getLoc(); 68 auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); 69 StringRef funcName = funcOp ? funcOp.getName() : "unknown"; 70 llvm::Constant *locStr = 71 createSourceLocStrFromLocation(loc, builder, funcName); 72 return builder.getOrCreateIdent(locStr); 73 } 74 75 /// Create a constant string representing the mapping information extracted from 76 /// the MLIR location information. 77 static llvm::Constant *createMappingInformation(Location loc, 78 OpenACCIRBuilder &builder) { 79 if (auto nameLoc = loc.dyn_cast<NameLoc>()) { 80 StringRef name = nameLoc.getName(); 81 return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name); 82 } else { 83 return createSourceLocStrFromLocation(loc, builder, "unknown"); 84 } 85 } 86 87 /// Return the runtime function used to lower the given operation. 88 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder, 89 Operation *op) { 90 return llvm::TypeSwitch<Operation *, llvm::Function *>(op) 91 .Case([&](acc::EnterDataOp) { 92 return builder.getOrCreateRuntimeFunctionPtr( 93 llvm::omp::OMPRTL___tgt_target_data_begin_mapper); 94 }) 95 .Case([&](acc::ExitDataOp) { 96 return builder.getOrCreateRuntimeFunctionPtr( 97 llvm::omp::OMPRTL___tgt_target_data_end_mapper); 98 }); 99 llvm_unreachable("Unknown OpenACC operation"); 100 } 101 102 /// Computes the size of type in bytes. 103 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder, 104 llvm::Value *basePtr) { 105 llvm::LLVMContext &ctx = builder.getContext(); 106 llvm::Value *null = 107 llvm::Constant::getNullValue(basePtr->getType()->getPointerTo()); 108 llvm::Value *sizeGep = 109 builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1)); 110 llvm::Value *sizePtrToInt = 111 builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx)); 112 return sizePtrToInt; 113 } 114 115 /// Extract pointer, size and mapping information from operands 116 /// to populate the future functions arguments. 117 static LogicalResult 118 processOperands(llvm::IRBuilderBase &builder, 119 LLVM::ModuleTranslation &moduleTranslation, Operation *op, 120 ValueRange operands, unsigned totalNbOperand, 121 uint64_t operandFlag, SmallVector<uint64_t> &flags, 122 SmallVector<llvm::Constant *> &names, unsigned &index, 123 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 124 llvm::AllocaInst *argSizes) { 125 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 126 llvm::LLVMContext &ctx = builder.getContext(); 127 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 128 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 129 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 130 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 131 132 for (Value data : operands) { 133 llvm::Value *dataValue = moduleTranslation.lookupValue(data); 134 135 llvm::Value *dataPtrBase; 136 llvm::Value *dataPtr; 137 llvm::Value *dataSize; 138 139 // Handle operands that were converted to DataDescriptor. 140 if (DataDescriptor::isValid(data)) { 141 dataPtrBase = 142 builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor); 143 dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor); 144 dataSize = 145 builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor); 146 } else if (data.getType().isa<LLVM::LLVMPointerType>()) { 147 dataPtrBase = dataValue; 148 dataPtr = dataValue; 149 dataSize = getSizeInBytes(builder, dataValue); 150 } else { 151 return op->emitOpError() 152 << "Data operand must be legalized before translation." 153 << "Unsupported type: " << data.getType(); 154 } 155 156 // Store base pointer extracted from operand into the i-th position of 157 // argBase. 158 llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP( 159 arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(index)}); 160 llvm::Value *ptrBaseCast = builder.CreateBitCast( 161 ptrBaseGEP, dataPtrBase->getType()->getPointerTo()); 162 builder.CreateStore(dataPtrBase, ptrBaseCast); 163 164 // Store pointer extracted from operand into the i-th position of args. 165 llvm::Value *ptrGEP = builder.CreateInBoundsGEP( 166 arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(index)}); 167 llvm::Value *ptrCast = 168 builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo()); 169 builder.CreateStore(dataPtr, ptrCast); 170 171 // Store size extracted from operand into the i-th position of argSizes. 172 llvm::Value *sizeGEP = builder.CreateInBoundsGEP( 173 arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)}); 174 builder.CreateStore(dataSize, sizeGEP); 175 176 flags.push_back(operandFlag); 177 llvm::Constant *mapName = 178 createMappingInformation(data.getLoc(), *accBuilder); 179 names.push_back(mapName); 180 ++index; 181 } 182 return success(); 183 } 184 185 /// Process data operands from acc::EnterDataOp 186 static LogicalResult 187 processDataOperands(llvm::IRBuilderBase &builder, 188 LLVM::ModuleTranslation &moduleTranslation, 189 acc::EnterDataOp op, SmallVector<uint64_t> &flags, 190 SmallVector<llvm::Constant *> &names, unsigned &index, 191 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 192 llvm::AllocaInst *argSizes) { 193 // TODO add `create_zero` and `attach` operands 194 195 // Create operands are handled as `alloc` call. 196 if (failed(processOperands(builder, moduleTranslation, op, 197 op.createOperands(), op.getNumDataOperands(), 198 createFlag, flags, names, index, argsBase, args, 199 argSizes))) 200 return failure(); 201 202 // Copyin operands are handled as `to` call. 203 if (failed(processOperands(builder, moduleTranslation, op, 204 op.copyinOperands(), op.getNumDataOperands(), 205 copyinFlag, flags, names, index, argsBase, args, 206 argSizes))) 207 return failure(); 208 209 return success(); 210 } 211 212 /// Process data operands from acc::ExitDataOp 213 static LogicalResult 214 processDataOperands(llvm::IRBuilderBase &builder, 215 LLVM::ModuleTranslation &moduleTranslation, 216 acc::ExitDataOp op, SmallVector<uint64_t> &flags, 217 SmallVector<llvm::Constant *> &names, unsigned &index, 218 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 219 llvm::AllocaInst *argSizes) { 220 // TODO add `detach` operands 221 222 // Delete operands are handled as `delete` call. 223 if (failed(processOperands(builder, moduleTranslation, op, 224 op.deleteOperands(), op.getNumDataOperands(), 225 deleteFlag, flags, names, index, argsBase, args, 226 argSizes))) 227 return failure(); 228 229 // Copyout operands are handled as `from` call. 230 if (failed(processOperands(builder, moduleTranslation, op, 231 op.copyoutOperands(), op.getNumDataOperands(), 232 copyoutFlag, flags, names, index, argsBase, args, 233 argSizes))) 234 return failure(); 235 236 return success(); 237 } 238 239 //===----------------------------------------------------------------------===// 240 // Conversion functions 241 //===----------------------------------------------------------------------===// 242 243 /// Converts an OpenACC standalone data operation into LLVM IR. 244 template <typename OpTy> 245 static LogicalResult 246 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder, 247 LLVM::ModuleTranslation &moduleTranslation) { 248 auto enclosingFuncOp = 249 op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>(); 250 llvm::Function *enclosingFunction = 251 moduleTranslation.lookupFunction(enclosingFuncOp.getName()); 252 253 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 254 255 auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op); 256 auto *mapperFunc = getAssociatedFunction(*accBuilder, op); 257 258 // Number of arguments in the enter_data operation. 259 unsigned totalNbOperand = op.getNumDataOperands(); 260 261 // TODO could be moved to OpenXXIRBuilder? 262 llvm::LLVMContext &ctx = builder.getContext(); 263 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 264 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 265 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 266 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 267 llvm::IRBuilder<>::InsertPoint allocaIP( 268 &enclosingFunction->getEntryBlock(), 269 enclosingFunction->getEntryBlock().getFirstInsertionPt()); 270 llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP(); 271 builder.restoreIP(allocaIP); 272 llvm::AllocaInst *argsBase = builder.CreateAlloca(arrI8PtrTy); 273 llvm::AllocaInst *args = builder.CreateAlloca(arrI8PtrTy); 274 llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty); 275 builder.restoreIP(currentIP); 276 277 SmallVector<uint64_t> flags; 278 SmallVector<llvm::Constant *> names; 279 unsigned index = 0; 280 281 if (failed(processDataOperands(builder, moduleTranslation, op, flags, names, 282 index, argsBase, args, argSizes))) 283 return failure(); 284 285 llvm::GlobalVariable *maptypes = 286 accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); 287 llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( 288 llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), 289 maptypes, /*Idx0=*/0, /*Idx1=*/0); 290 291 llvm::GlobalVariable *mapnames = 292 accBuilder->createOffloadMapnames(names, ".offload_mapnames"); 293 llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( 294 llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), 295 mapnames, /*Idx0=*/0, /*Idx1=*/0); 296 297 llvm::Value *argsBaseGEP = builder.CreateInBoundsGEP( 298 arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(0)}); 299 llvm::Value *argsGEP = builder.CreateInBoundsGEP( 300 arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(0)}); 301 llvm::Value *argSizesGEP = builder.CreateInBoundsGEP( 302 arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)}); 303 llvm::Value *nullPtr = llvm::Constant::getNullValue( 304 llvm::Type::getInt8PtrTy(ctx)->getPointerTo()); 305 306 builder.CreateCall(mapperFunc, 307 {srcLocInfo, builder.getInt64(defaultDevice), 308 builder.getInt32(totalNbOperand), argsBaseGEP, argsGEP, 309 argSizesGEP, maptypesArg, mapnamesArg, nullPtr}); 310 311 return success(); 312 } 313 314 namespace { 315 316 /// Implementation of the dialect interface that converts operations belonging 317 /// to the OpenACC dialect to LLVM IR. 318 class OpenACCDialectLLVMIRTranslationInterface 319 : public LLVMTranslationDialectInterface { 320 public: 321 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 322 323 /// Translates the given operation to LLVM IR using the provided IR builder 324 /// and saving the state in `moduleTranslation`. 325 LogicalResult 326 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 327 LLVM::ModuleTranslation &moduleTranslation) const final; 328 }; 329 330 } // end namespace 331 332 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR 333 /// (including OpenACC runtime calls). 334 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation( 335 Operation *op, llvm::IRBuilderBase &builder, 336 LLVM::ModuleTranslation &moduleTranslation) const { 337 338 return llvm::TypeSwitch<Operation *, LogicalResult>(op) 339 .Case([&](acc::EnterDataOp enterDataOp) { 340 return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder, 341 moduleTranslation); 342 }) 343 .Case([&](acc::ExitDataOp exitDataOp) { 344 return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder, 345 moduleTranslation); 346 }) 347 .Default([&](Operation *op) { 348 return op->emitError("unsupported OpenACC operation: ") 349 << op->getName(); 350 }); 351 } 352 353 void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) { 354 registry.insert<acc::OpenACCDialect>(); 355 registry.addDialectInterface<acc::OpenACCDialect, 356 OpenACCDialectLLVMIRTranslationInterface>(); 357 } 358 359 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) { 360 DialectRegistry registry; 361 registerOpenACCDialectTranslation(registry); 362 context.appendDialectRegistry(registry); 363 } 364