1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenACC dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/OpenACC/OpenACC.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
22 
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Frontend/OpenMP/OMPConstants.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 using namespace mlir;
28 
29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
/// Map-type flag recorded for `create` operands (0 = alloc/create).
static constexpr uint64_t createFlag{0};
/// Map-type flag recorded for `copyin` operands (1 = to/copyin).
static constexpr uint64_t copyinFlag{1};
/// Device id passed to the runtime mapper call (default value for the device
/// id).
static constexpr int64_t defaultDevice{-1};
41 
42 /// Create a constant string location from the MLIR Location information.
43 static llvm::Constant *createSourceLocStrFromLocation(Location loc,
44                                                       OpenACCIRBuilder &builder,
45                                                       StringRef name) {
46   if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
47     StringRef fileName = fileLoc.getFilename();
48     unsigned lineNo = fileLoc.getLine();
49     unsigned colNo = fileLoc.getColumn();
50     return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo);
51   } else {
52     std::string locStr;
53     llvm::raw_string_ostream locOS(locStr);
54     locOS << loc;
55     return builder.getOrCreateSrcLocStr(locOS.str());
56   }
57 }
58 
59 /// Create the location struct from the operation location information.
60 static llvm::Value *createSourceLocationInfo(acc::EnterDataOp &op,
61                                              OpenACCIRBuilder &builder) {
62   auto loc = op.getLoc();
63   auto funcOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
64   StringRef funcName = funcOp ? funcOp.getName() : "unknown";
65   llvm::Constant *locStr =
66       createSourceLocStrFromLocation(loc, builder, funcName);
67   return builder.getOrCreateIdent(locStr);
68 }
69 
70 /// Create a constant string representing the mapping information extracted from
71 /// the MLIR location information.
72 static llvm::Constant *createMappingInformation(Location loc,
73                                                 OpenACCIRBuilder &builder) {
74   if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
75     StringRef name = nameLoc.getName();
76     return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name);
77   } else {
78     return createSourceLocStrFromLocation(loc, builder, "unknown");
79   }
80 }
81 
82 /// Return the runtime function used to lower the given operation.
83 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
84                                              Operation &op) {
85   if (isa<acc::EnterDataOp>(op))
86     return builder.getOrCreateRuntimeFunctionPtr(
87         llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
88   llvm_unreachable("Unknown OpenACC operation");
89 }
90 
91 /// Computes the size of type in bytes.
92 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder,
93                                    llvm::Value *basePtr) {
94   llvm::LLVMContext &ctx = builder.getContext();
95   llvm::Value *null =
96       llvm::Constant::getNullValue(basePtr->getType()->getPointerTo());
97   llvm::Value *sizeGep =
98       builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1));
99   llvm::Value *sizePtrToInt =
100       builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx));
101   return sizePtrToInt;
102 }
103 
104 /// Extract pointer, size and mapping information from operands
105 /// to populate the future functions arguments.
106 static LogicalResult
107 processOperands(llvm::IRBuilderBase &builder,
108                 LLVM::ModuleTranslation &moduleTranslation, Operation &op,
109                 ValueRange operands, unsigned totalNbOperand,
110                 uint64_t operandFlag, SmallVector<uint64_t> &flags,
111                 SmallVector<llvm::Constant *> &names, unsigned &index,
112                 llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
113                 llvm::AllocaInst *argSizes) {
114   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
115   llvm::LLVMContext &ctx = builder.getContext();
116   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
117   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
118   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
119   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
120 
121   for (Value data : operands) {
122     llvm::Value *dataValue = moduleTranslation.lookupValue(data);
123 
124     llvm::Value *dataPtrBase;
125     llvm::Value *dataPtr;
126     llvm::Value *dataSize;
127 
128     // Handle operands that were converted to DataDescriptor.
129     if (DataDescriptor::isValid(data)) {
130       dataPtrBase =
131           builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor);
132       dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor);
133       dataSize =
134           builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor);
135     } else if (data.getType().isa<LLVM::LLVMPointerType>()) {
136       dataPtrBase = dataValue;
137       dataPtr = dataValue;
138       dataSize = getSizeInBytes(builder, dataValue);
139     } else {
140       return op.emitOpError()
141              << "Data operand must be legalized before translation."
142              << "Unsupported type: " << data.getType();
143     }
144 
145     // Store base pointer extracted from operand into the i-th position of
146     // argBase.
147     llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
148         arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(index)});
149     llvm::Value *ptrBaseCast = builder.CreateBitCast(
150         ptrBaseGEP, dataPtrBase->getType()->getPointerTo());
151     builder.CreateStore(dataPtrBase, ptrBaseCast);
152 
153     // Store pointer extracted from operand into the i-th position of args.
154     llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
155         arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(index)});
156     llvm::Value *ptrCast =
157         builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo());
158     builder.CreateStore(dataPtr, ptrCast);
159 
160     // Store size extracted from operand into the i-th position of argSizes.
161     llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
162         arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)});
163     builder.CreateStore(dataSize, sizeGEP);
164 
165     flags.push_back(operandFlag);
166     llvm::Constant *mapName =
167         createMappingInformation(data.getLoc(), *accBuilder);
168     names.push_back(mapName);
169     ++index;
170   }
171   return success();
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Conversion functions
176 //===----------------------------------------------------------------------===//
177 
178 /// Converts an OpenACC enter_data operartion into LLVM IR.
179 static LogicalResult
180 convertEnterDataOp(Operation &op, llvm::IRBuilderBase &builder,
181                    LLVM::ModuleTranslation &moduleTranslation) {
182   auto enterDataOp = cast<acc::EnterDataOp>(op);
183   auto enclosingFuncOp = op.getParentOfType<LLVM::LLVMFuncOp>();
184   llvm::Function *enclosingFunction =
185       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
186 
187   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
188 
189   auto *srcLocInfo = createSourceLocationInfo(enterDataOp, *accBuilder);
190   auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
191 
192   // Number of arguments in the enter_data operation.
193   // TODO include create_zero and attach operands.
194   unsigned totalNbOperand =
195       enterDataOp.createOperands().size() + enterDataOp.copyinOperands().size();
196 
197   // TODO could be moved to OpenXXIRBuilder?
198   llvm::LLVMContext &ctx = builder.getContext();
199   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
200   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
201   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
202   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
203   llvm::IRBuilder<>::InsertPoint allocaIP(
204       &enclosingFunction->getEntryBlock(),
205       enclosingFunction->getEntryBlock().getFirstInsertionPt());
206   llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP();
207   builder.restoreIP(allocaIP);
208   llvm::AllocaInst *argsBase = builder.CreateAlloca(arrI8PtrTy);
209   llvm::AllocaInst *args = builder.CreateAlloca(arrI8PtrTy);
210   llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty);
211   builder.restoreIP(currentIP);
212 
213   SmallVector<uint64_t> flags;
214   SmallVector<llvm::Constant *> names;
215   unsigned index = 0;
216 
217   // Create operands are handled as `alloc` call.
218   if (failed(processOperands(builder, moduleTranslation, op,
219                              enterDataOp.createOperands(), totalNbOperand,
220                              createFlag, flags, names, index, argsBase, args,
221                              argSizes)))
222     return failure();
223 
224   // Copyin operands are handled as `to` call.
225   if (failed(processOperands(builder, moduleTranslation, op,
226                              enterDataOp.copyinOperands(), totalNbOperand,
227                              copyinFlag, flags, names, index, argsBase, args,
228                              argSizes)))
229     return failure();
230 
231   llvm::GlobalVariable *maptypes =
232       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
233   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
234       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
235       maptypes, /*Idx0=*/0, /*Idx1=*/0);
236 
237   llvm::GlobalVariable *mapnames =
238       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
239   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
240       llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
241       mapnames, /*Idx0=*/0, /*Idx1=*/0);
242 
243   llvm::Value *argsBaseGEP = builder.CreateInBoundsGEP(
244       arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(0)});
245   llvm::Value *argsGEP = builder.CreateInBoundsGEP(
246       arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(0)});
247   llvm::Value *argSizesGEP = builder.CreateInBoundsGEP(
248       arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)});
249   llvm::Value *nullPtr = llvm::Constant::getNullValue(
250       llvm::Type::getInt8PtrTy(ctx)->getPointerTo());
251 
252   builder.CreateCall(mapperFunc,
253                      {srcLocInfo, builder.getInt64(defaultDevice),
254                       builder.getInt32(totalNbOperand), argsBaseGEP, argsGEP,
255                       argSizesGEP, maptypesArg, mapnamesArg, nullPtr});
256 
257   return success();
258 }
259 
260 namespace {
261 
262 /// Implementation of the dialect interface that converts operations belonging
263 /// to the OpenACC dialect to LLVM IR.
264 class OpenACCDialectLLVMIRTranslationInterface
265     : public LLVMTranslationDialectInterface {
266 public:
267   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
268 
269   /// Translates the given operation to LLVM IR using the provided IR builder
270   /// and saving the state in `moduleTranslation`.
271   LogicalResult
272   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
273                    LLVM::ModuleTranslation &moduleTranslation) const final;
274 };
275 
276 } // end namespace
277 
278 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
279 /// (including OpenACC runtime calls).
280 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
281     Operation *op, llvm::IRBuilderBase &builder,
282     LLVM::ModuleTranslation &moduleTranslation) const {
283 
284   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
285       .Case([&](acc::EnterDataOp) {
286         return convertEnterDataOp(*op, builder, moduleTranslation);
287       })
288       .Default([&](Operation *op) {
289         return op->emitError("unsupported OpenACC operation: ")
290                << op->getName();
291       });
292 }
293 
/// Register the OpenACC dialect in `registry` and attach the OpenACC LLVM IR
/// translation interface to it.
void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
  // NOTE(review): the dialect is inserted before its interface is added;
  // presumably addDialectInterface expects the dialect to be registered
  // already, so keep this order.
  registry.insert<acc::OpenACCDialect>();
  registry.addDialectInterface<acc::OpenACCDialect,
                               OpenACCDialectLLVMIRTranslationInterface>();
}
299 
300 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
301   DialectRegistry registry;
302   registerOpenACCDialectTranslation(registry);
303   context.appendDialectRegistry(registry);
304 }
305