1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenACC dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/OpenACC/OpenACC.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
22 
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Frontend/OpenMP/OMPConstants.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 using namespace mlir;
28 
29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
// Map-type flag values recorded per data operand; they are collected into the
// `.offload_maptypes` array that is handed to the runtime call.

/// Flag used for `create` operands (alloc).
static constexpr uint64_t kCreateFlag = 0x0;
/// Flag used for `copyin` and `device` operands (to).
static constexpr uint64_t kDeviceCopyinFlag = 0x1;
/// Flag used for `copyout` and `host` operands (from).
static constexpr uint64_t kHostCopyoutFlag = 0x2;
/// Flag used for `delete` operands.
static constexpr uint64_t kDeleteFlag = 0x8;

/// Device id passed to the runtime calls emitted by this file.
static constexpr int64_t kDefaultDevice = -1;
46 
47 /// Create a constant string location from the MLIR Location information.
48 static llvm::Constant *createSourceLocStrFromLocation(Location loc,
49                                                       OpenACCIRBuilder &builder,
50                                                       StringRef name) {
51   if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
52     StringRef fileName = fileLoc.getFilename();
53     unsigned lineNo = fileLoc.getLine();
54     unsigned colNo = fileLoc.getColumn();
55     return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo);
56   } else {
57     std::string locStr;
58     llvm::raw_string_ostream locOS(locStr);
59     locOS << loc;
60     return builder.getOrCreateSrcLocStr(locOS.str());
61   }
62 }
63 
64 /// Create the location struct from the operation location information.
65 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder,
66                                              Operation *op) {
67   auto loc = op->getLoc();
68   auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
69   StringRef funcName = funcOp ? funcOp.getName() : "unknown";
70   llvm::Constant *locStr =
71       createSourceLocStrFromLocation(loc, builder, funcName);
72   return builder.getOrCreateIdent(locStr);
73 }
74 
75 /// Create a constant string representing the mapping information extracted from
76 /// the MLIR location information.
77 static llvm::Constant *createMappingInformation(Location loc,
78                                                 OpenACCIRBuilder &builder) {
79   if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
80     StringRef name = nameLoc.getName();
81     return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name);
82   } else {
83     return createSourceLocStrFromLocation(loc, builder, "unknown");
84   }
85 }
86 
87 /// Return the runtime function used to lower the given operation.
88 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
89                                              Operation *op) {
90   return llvm::TypeSwitch<Operation *, llvm::Function *>(op)
91       .Case([&](acc::EnterDataOp) {
92         return builder.getOrCreateRuntimeFunctionPtr(
93             llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
94       })
95       .Case([&](acc::ExitDataOp) {
96         return builder.getOrCreateRuntimeFunctionPtr(
97             llvm::omp::OMPRTL___tgt_target_data_end_mapper);
98       })
99       .Case([&](acc::UpdateOp) {
100         return builder.getOrCreateRuntimeFunctionPtr(
101             llvm::omp::OMPRTL___tgt_target_data_update_mapper);
102       });
103   llvm_unreachable("Unknown OpenACC operation");
104 }
105 
106 /// Computes the size of type in bytes.
107 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder,
108                                    llvm::Value *basePtr) {
109   llvm::LLVMContext &ctx = builder.getContext();
110   llvm::Value *null =
111       llvm::Constant::getNullValue(basePtr->getType()->getPointerTo());
112   llvm::Value *sizeGep =
113       builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1));
114   llvm::Value *sizePtrToInt =
115       builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx));
116   return sizePtrToInt;
117 }
118 
119 /// Extract pointer, size and mapping information from operands
120 /// to populate the future functions arguments.
121 static LogicalResult
122 processOperands(llvm::IRBuilderBase &builder,
123                 LLVM::ModuleTranslation &moduleTranslation, Operation *op,
124                 ValueRange operands, unsigned totalNbOperand,
125                 uint64_t operandFlag, SmallVector<uint64_t> &flags,
126                 SmallVector<llvm::Constant *> &names, unsigned &index,
127                 llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
128                 llvm::AllocaInst *argSizes) {
129   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
130   llvm::LLVMContext &ctx = builder.getContext();
131   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
132   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
133   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
134   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
135 
136   for (Value data : operands) {
137     llvm::Value *dataValue = moduleTranslation.lookupValue(data);
138 
139     llvm::Value *dataPtrBase;
140     llvm::Value *dataPtr;
141     llvm::Value *dataSize;
142 
143     // Handle operands that were converted to DataDescriptor.
144     if (DataDescriptor::isValid(data)) {
145       dataPtrBase =
146           builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor);
147       dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor);
148       dataSize =
149           builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor);
150     } else if (data.getType().isa<LLVM::LLVMPointerType>()) {
151       dataPtrBase = dataValue;
152       dataPtr = dataValue;
153       dataSize = getSizeInBytes(builder, dataValue);
154     } else {
155       return op->emitOpError()
156              << "Data operand must be legalized before translation."
157              << "Unsupported type: " << data.getType();
158     }
159 
160     // Store base pointer extracted from operand into the i-th position of
161     // argBase.
162     llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
163         arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(index)});
164     llvm::Value *ptrBaseCast = builder.CreateBitCast(
165         ptrBaseGEP, dataPtrBase->getType()->getPointerTo());
166     builder.CreateStore(dataPtrBase, ptrBaseCast);
167 
168     // Store pointer extracted from operand into the i-th position of args.
169     llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
170         arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(index)});
171     llvm::Value *ptrCast =
172         builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo());
173     builder.CreateStore(dataPtr, ptrCast);
174 
175     // Store size extracted from operand into the i-th position of argSizes.
176     llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
177         arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)});
178     builder.CreateStore(dataSize, sizeGEP);
179 
180     flags.push_back(operandFlag);
181     llvm::Constant *mapName =
182         createMappingInformation(data.getLoc(), *accBuilder);
183     names.push_back(mapName);
184     ++index;
185   }
186   return success();
187 }
188 
189 /// Process data operands from acc::EnterDataOp
190 static LogicalResult
191 processDataOperands(llvm::IRBuilderBase &builder,
192                     LLVM::ModuleTranslation &moduleTranslation,
193                     acc::EnterDataOp op, SmallVector<uint64_t> &flags,
194                     SmallVector<llvm::Constant *> &names, unsigned &index,
195                     llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
196                     llvm::AllocaInst *argSizes) {
197   // TODO add `create_zero` and `attach` operands
198 
199   // Create operands are handled as `alloc` call.
200   if (failed(processOperands(builder, moduleTranslation, op,
201                              op.createOperands(), op.getNumDataOperands(),
202                              kCreateFlag, flags, names, index, argsBase, args,
203                              argSizes)))
204     return failure();
205 
206   // Copyin operands are handled as `to` call.
207   if (failed(processOperands(builder, moduleTranslation, op,
208                              op.copyinOperands(), op.getNumDataOperands(),
209                              kDeviceCopyinFlag, flags, names, index, argsBase,
210                              args, argSizes)))
211     return failure();
212 
213   return success();
214 }
215 
216 /// Process data operands from acc::ExitDataOp
217 static LogicalResult
218 processDataOperands(llvm::IRBuilderBase &builder,
219                     LLVM::ModuleTranslation &moduleTranslation,
220                     acc::ExitDataOp op, SmallVector<uint64_t> &flags,
221                     SmallVector<llvm::Constant *> &names, unsigned &index,
222                     llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
223                     llvm::AllocaInst *argSizes) {
224   // TODO add `detach` operands
225 
226   // Delete operands are handled as `delete` call.
227   if (failed(processOperands(builder, moduleTranslation, op,
228                              op.deleteOperands(), op.getNumDataOperands(),
229                              kDeleteFlag, flags, names, index, argsBase, args,
230                              argSizes)))
231     return failure();
232 
233   // Copyout operands are handled as `from` call.
234   if (failed(processOperands(builder, moduleTranslation, op,
235                              op.copyoutOperands(), op.getNumDataOperands(),
236                              kHostCopyoutFlag, flags, names, index, argsBase,
237                              args, argSizes)))
238     return failure();
239 
240   return success();
241 }
242 
243 /// Process data operands from acc::UpdateOp
244 static LogicalResult
245 processDataOperands(llvm::IRBuilderBase &builder,
246                     LLVM::ModuleTranslation &moduleTranslation,
247                     acc::UpdateOp op, SmallVector<uint64_t> &flags,
248                     SmallVector<llvm::Constant *> &names, unsigned &index,
249                     llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
250                     llvm::AllocaInst *argSizes) {
251 
252   // Host operands are handled as `from` call.
253   if (failed(processOperands(builder, moduleTranslation, op, op.hostOperands(),
254                              op.getNumDataOperands(), kHostCopyoutFlag, flags,
255                              names, index, argsBase, args, argSizes)))
256     return failure();
257 
258   // Device operands are handled as `to` call.
259   if (failed(processOperands(builder, moduleTranslation, op,
260                              op.deviceOperands(), op.getNumDataOperands(),
261                              kDeviceCopyinFlag, flags, names, index, argsBase,
262                              args, argSizes)))
263     return failure();
264 
265   return success();
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // Conversion functions
270 //===----------------------------------------------------------------------===//
271 
272 /// Converts an OpenACC standalone data operation into LLVM IR.
273 template <typename OpTy>
274 static LogicalResult
275 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
276                         LLVM::ModuleTranslation &moduleTranslation) {
277   auto enclosingFuncOp =
278       op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>();
279   llvm::Function *enclosingFunction =
280       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
281 
282   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
283 
284   auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
285   auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
286 
287   // Number of arguments in the enter_data operation.
288   unsigned totalNbOperand = op.getNumDataOperands();
289 
290   // TODO could be moved to OpenXXIRBuilder?
291   llvm::LLVMContext &ctx = builder.getContext();
292   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
293   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
294   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
295   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
296   llvm::IRBuilder<>::InsertPoint allocaIP(
297       &enclosingFunction->getEntryBlock(),
298       enclosingFunction->getEntryBlock().getFirstInsertionPt());
299   llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP();
300   builder.restoreIP(allocaIP);
301   llvm::AllocaInst *argsBase = builder.CreateAlloca(arrI8PtrTy);
302   llvm::AllocaInst *args = builder.CreateAlloca(arrI8PtrTy);
303   llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty);
304   builder.restoreIP(currentIP);
305 
306   SmallVector<uint64_t> flags;
307   SmallVector<llvm::Constant *> names;
308   unsigned index = 0;
309 
310   if (failed(processDataOperands(builder, moduleTranslation, op, flags, names,
311                                  index, argsBase, args, argSizes)))
312     return failure();
313 
314   llvm::GlobalVariable *maptypes =
315       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
316   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
317       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
318       maptypes, /*Idx0=*/0, /*Idx1=*/0);
319 
320   llvm::GlobalVariable *mapnames =
321       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
322   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
323       llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
324       mapnames, /*Idx0=*/0, /*Idx1=*/0);
325 
326   llvm::Value *argsBaseGEP = builder.CreateInBoundsGEP(
327       arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(0)});
328   llvm::Value *argsGEP = builder.CreateInBoundsGEP(
329       arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(0)});
330   llvm::Value *argSizesGEP = builder.CreateInBoundsGEP(
331       arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)});
332   llvm::Value *nullPtr = llvm::Constant::getNullValue(
333       llvm::Type::getInt8PtrTy(ctx)->getPointerTo());
334 
335   builder.CreateCall(mapperFunc,
336                      {srcLocInfo, builder.getInt64(kDefaultDevice),
337                       builder.getInt32(totalNbOperand), argsBaseGEP, argsGEP,
338                       argSizesGEP, maptypesArg, mapnamesArg, nullPtr});
339 
340   return success();
341 }
342 
namespace {

/// Implementation of the dialect interface that converts operations belonging
/// to the OpenACC dialect to LLVM IR. It is registered for acc::OpenACCDialect
/// by registerOpenACCDialectTranslation.
class OpenACCDialectLLVMIRTranslationInterface
    : public LLVMTranslationDialectInterface {
public:
  // Reuse the constructors of the base interface.
  using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;

  /// Translates the given operation to LLVM IR using the provided IR builder
  /// and saving the state in `moduleTranslation`. Defined out of line below;
  /// operations other than acc::EnterDataOp, acc::ExitDataOp and acc::UpdateOp
  /// are rejected with an error.
  LogicalResult
  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) const final;
};

} // end anonymous namespace
360 
361 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
362 /// (including OpenACC runtime calls).
363 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
364     Operation *op, llvm::IRBuilderBase &builder,
365     LLVM::ModuleTranslation &moduleTranslation) const {
366 
367   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
368       .Case([&](acc::EnterDataOp enterDataOp) {
369         return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder,
370                                                          moduleTranslation);
371       })
372       .Case([&](acc::ExitDataOp exitDataOp) {
373         return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder,
374                                                         moduleTranslation);
375       })
376       .Case([&](acc::UpdateOp updateOp) {
377         return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder,
378                                                       moduleTranslation);
379       })
380       .Default([&](Operation *op) {
381         return op->emitError("unsupported OpenACC operation: ")
382                << op->getName();
383       });
384 }
385 
386 void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
387   registry.insert<acc::OpenACCDialect>();
388   registry.addDialectInterface<acc::OpenACCDialect,
389                                OpenACCDialectLLVMIRTranslationInterface>();
390 }
391 
392 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
393   DialectRegistry registry;
394   registerOpenACCDialectTranslation(registry);
395   context.appendDialectRegistry(registry);
396 }
397