1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a translation between the MLIR OpenACC dialect and LLVM 10 // IR. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h" 15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h" 16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 17 #include "mlir/Dialect/OpenACC/OpenACC.h" 18 #include "mlir/IR/BuiltinOps.h" 19 #include "mlir/IR/Operation.h" 20 #include "mlir/Support/LLVM.h" 21 #include "mlir/Target/LLVMIR/ModuleTranslation.h" 22 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/Frontend/OpenMP/OMPConstants.h" 25 #include "llvm/Support/FormatVariadic.h" 26 27 using namespace mlir; 28 29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder; 30 31 //===----------------------------------------------------------------------===// 32 // Utility functions 33 //===----------------------------------------------------------------------===// 34 35 /// 0 = alloc/create 36 static constexpr uint64_t createFlag = 0; 37 /// 1 = to/copyin 38 static constexpr uint64_t copyinFlag = 1; 39 /// Default value for the device id 40 static constexpr int64_t defaultDevice = -1; 41 42 /// Create a constant string location from the MLIR Location information. 43 static llvm::Constant *createSourceLocStrFromLocation(Location loc, 44 OpenACCIRBuilder &builder, 45 StringRef name) { 46 if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) { 47 StringRef fileName = fileLoc.getFilename(); 48 unsigned lineNo = fileLoc.getLine(); 49 unsigned colNo = fileLoc.getColumn(); 50 return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo); 51 } else { 52 std::string locStr; 53 llvm::raw_string_ostream locOS(locStr); 54 locOS << loc; 55 return builder.getOrCreateSrcLocStr(locOS.str()); 56 } 57 } 58 59 /// Create the location struct from the operation location information. 60 static llvm::Value *createSourceLocationInfo(acc::EnterDataOp &op, 61 OpenACCIRBuilder &builder) { 62 auto loc = op.getLoc(); 63 auto funcOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>(); 64 StringRef funcName = funcOp ? funcOp.getName() : "unknown"; 65 llvm::Constant *locStr = 66 createSourceLocStrFromLocation(loc, builder, funcName); 67 return builder.getOrCreateIdent(locStr); 68 } 69 70 /// Create a constant string representing the mapping information extracted from 71 /// the MLIR location information. 72 static llvm::Constant *createMappingInformation(Location loc, 73 OpenACCIRBuilder &builder) { 74 if (auto nameLoc = loc.dyn_cast<NameLoc>()) { 75 StringRef name = nameLoc.getName(); 76 return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name); 77 } else { 78 return createSourceLocStrFromLocation(loc, builder, "unknown"); 79 } 80 } 81 82 /// Return the runtime function used to lower the given operation. 83 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder, 84 Operation &op) { 85 if (isa<acc::EnterDataOp>(op)) 86 return builder.getOrCreateRuntimeFunctionPtr( 87 llvm::omp::OMPRTL___tgt_target_data_begin_mapper); 88 llvm_unreachable("Unknown OpenACC operation"); 89 } 90 91 /// Computes the size of type in bytes. 92 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder, 93 llvm::Value *basePtr) { 94 llvm::LLVMContext &ctx = builder.getContext(); 95 llvm::Value *null = 96 llvm::Constant::getNullValue(basePtr->getType()->getPointerTo()); 97 llvm::Value *sizeGep = 98 builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1)); 99 llvm::Value *sizePtrToInt = 100 builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx)); 101 return sizePtrToInt; 102 } 103 104 /// Extract pointer, size and mapping information from operands 105 /// to populate the future functions arguments. 106 static LogicalResult 107 processOperands(llvm::IRBuilderBase &builder, 108 LLVM::ModuleTranslation &moduleTranslation, Operation &op, 109 ValueRange operands, unsigned totalNbOperand, 110 uint64_t operandFlag, SmallVector<uint64_t> &flags, 111 SmallVector<llvm::Constant *> &names, unsigned &index, 112 llvm::AllocaInst *argsBase, llvm::AllocaInst *args, 113 llvm::AllocaInst *argSizes) { 114 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 115 llvm::LLVMContext &ctx = builder.getContext(); 116 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 117 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 118 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 119 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 120 121 for (Value data : operands) { 122 llvm::Value *dataValue = moduleTranslation.lookupValue(data); 123 124 llvm::Value *dataPtrBase; 125 llvm::Value *dataPtr; 126 llvm::Value *dataSize; 127 128 // Handle operands that were converted to DataDescriptor. 129 if (DataDescriptor::isValid(data)) { 130 dataPtrBase = 131 builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor); 132 dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor); 133 dataSize = 134 builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor); 135 } else if (data.getType().isa<LLVM::LLVMPointerType>()) { 136 dataPtrBase = dataValue; 137 dataPtr = dataValue; 138 dataSize = getSizeInBytes(builder, dataValue); 139 } else { 140 return op.emitOpError() 141 << "Data operand must be legalized before translation." 142 << "Unsupported type: " << data.getType(); 143 } 144 145 // Store base pointer extracted from operand into the i-th position of 146 // argBase. 147 llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP( 148 arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(index)}); 149 llvm::Value *ptrBaseCast = builder.CreateBitCast( 150 ptrBaseGEP, dataPtrBase->getType()->getPointerTo()); 151 builder.CreateStore(dataPtrBase, ptrBaseCast); 152 153 // Store pointer extracted from operand into the i-th position of args. 154 llvm::Value *ptrGEP = builder.CreateInBoundsGEP( 155 arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(index)}); 156 llvm::Value *ptrCast = 157 builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo()); 158 builder.CreateStore(dataPtr, ptrCast); 159 160 // Store size extracted from operand into the i-th position of argSizes. 161 llvm::Value *sizeGEP = builder.CreateInBoundsGEP( 162 arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)}); 163 builder.CreateStore(dataSize, sizeGEP); 164 165 flags.push_back(operandFlag); 166 llvm::Constant *mapName = 167 createMappingInformation(data.getLoc(), *accBuilder); 168 names.push_back(mapName); 169 ++index; 170 } 171 return success(); 172 } 173 174 //===----------------------------------------------------------------------===// 175 // Conversion functions 176 //===----------------------------------------------------------------------===// 177 178 /// Converts an OpenACC enter_data operartion into LLVM IR. 179 static LogicalResult 180 convertEnterDataOp(Operation &op, llvm::IRBuilderBase &builder, 181 LLVM::ModuleTranslation &moduleTranslation) { 182 auto enterDataOp = cast<acc::EnterDataOp>(op); 183 auto enclosingFuncOp = op.getParentOfType<LLVM::LLVMFuncOp>(); 184 llvm::Function *enclosingFunction = 185 moduleTranslation.lookupFunction(enclosingFuncOp.getName()); 186 187 OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder(); 188 189 auto *srcLocInfo = createSourceLocationInfo(enterDataOp, *accBuilder); 190 auto *mapperFunc = getAssociatedFunction(*accBuilder, op); 191 192 // Number of arguments in the enter_data operation. 193 // TODO include create_zero and attach operands. 194 unsigned totalNbOperand = 195 enterDataOp.createOperands().size() + enterDataOp.copyinOperands().size(); 196 197 // TODO could be moved to OpenXXIRBuilder? 198 llvm::LLVMContext &ctx = builder.getContext(); 199 auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx); 200 auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand); 201 auto *i64Ty = llvm::Type::getInt64Ty(ctx); 202 auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand); 203 llvm::IRBuilder<>::InsertPoint allocaIP( 204 &enclosingFunction->getEntryBlock(), 205 enclosingFunction->getEntryBlock().getFirstInsertionPt()); 206 llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP(); 207 builder.restoreIP(allocaIP); 208 llvm::AllocaInst *argsBase = builder.CreateAlloca(arrI8PtrTy); 209 llvm::AllocaInst *args = builder.CreateAlloca(arrI8PtrTy); 210 llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty); 211 builder.restoreIP(currentIP); 212 213 SmallVector<uint64_t> flags; 214 SmallVector<llvm::Constant *> names; 215 unsigned index = 0; 216 217 // Create operands are handled as `alloc` call. 218 if (failed(processOperands(builder, moduleTranslation, op, 219 enterDataOp.createOperands(), totalNbOperand, 220 createFlag, flags, names, index, argsBase, args, 221 argSizes))) 222 return failure(); 223 224 // Copyin operands are handled as `to` call. 225 if (failed(processOperands(builder, moduleTranslation, op, 226 enterDataOp.copyinOperands(), totalNbOperand, 227 copyinFlag, flags, names, index, argsBase, args, 228 argSizes))) 229 return failure(); 230 231 llvm::GlobalVariable *maptypes = 232 accBuilder->createOffloadMaptypes(flags, ".offload_maptypes"); 233 llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32( 234 llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand), 235 maptypes, /*Idx0=*/0, /*Idx1=*/0); 236 237 llvm::GlobalVariable *mapnames = 238 accBuilder->createOffloadMapnames(names, ".offload_mapnames"); 239 llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32( 240 llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand), 241 mapnames, /*Idx0=*/0, /*Idx1=*/0); 242 243 llvm::Value *argsBaseGEP = builder.CreateInBoundsGEP( 244 arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(0)}); 245 llvm::Value *argsGEP = builder.CreateInBoundsGEP( 246 arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(0)}); 247 llvm::Value *argSizesGEP = builder.CreateInBoundsGEP( 248 arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)}); 249 llvm::Value *nullPtr = llvm::Constant::getNullValue( 250 llvm::Type::getInt8PtrTy(ctx)->getPointerTo()); 251 252 builder.CreateCall(mapperFunc, 253 {srcLocInfo, builder.getInt64(defaultDevice), 254 builder.getInt32(totalNbOperand), argsBaseGEP, argsGEP, 255 argSizesGEP, maptypesArg, mapnamesArg, nullPtr}); 256 257 return success(); 258 } 259 260 namespace { 261 262 /// Implementation of the dialect interface that converts operations belonging 263 /// to the OpenACC dialect to LLVM IR. 264 class OpenACCDialectLLVMIRTranslationInterface 265 : public LLVMTranslationDialectInterface { 266 public: 267 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; 268 269 /// Translates the given operation to LLVM IR using the provided IR builder 270 /// and saving the state in `moduleTranslation`. 271 LogicalResult 272 convertOperation(Operation *op, llvm::IRBuilderBase &builder, 273 LLVM::ModuleTranslation &moduleTranslation) const final; 274 }; 275 276 } // end namespace 277 278 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR 279 /// (including OpenACC runtime calls). 280 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation( 281 Operation *op, llvm::IRBuilderBase &builder, 282 LLVM::ModuleTranslation &moduleTranslation) const { 283 284 return llvm::TypeSwitch<Operation *, LogicalResult>(op) 285 .Case([&](acc::EnterDataOp) { 286 return convertEnterDataOp(*op, builder, moduleTranslation); 287 }) 288 .Default([&](Operation *op) { 289 return op->emitError("unsupported OpenACC operation: ") 290 << op->getName(); 291 }); 292 } 293 294 void mlir::registerOpenACCDialectTranslation(DialectRegistry ®istry) { 295 registry.insert<acc::OpenACCDialect>(); 296 registry.addDialectInterface<acc::OpenACCDialect, 297 OpenACCDialectLLVMIRTranslationInterface>(); 298 } 299 300 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) { 301 DialectRegistry registry; 302 registerOpenACCDialectTranslation(registry); 303 context.appendDialectRegistry(registry); 304 } 305