1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenACC dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/OpenACC/OpenACC.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
22 
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Frontend/OpenMP/OMPConstants.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 using namespace mlir;
28 
29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
// Map-type flags. One of these is recorded per data operand and the collected
// values become the `.offload_maptypes` array passed to the runtime call.

/// 0 = alloc/create (no bits set).
static constexpr uint64_t createFlag{0};
/// 1 = to/copyin.
static constexpr uint64_t copyinFlag{1};
/// 2 = from/copyout.
static constexpr uint64_t copyoutFlag{2};
/// 8 = delete.
static constexpr uint64_t deleteFlag{8};

/// Device id passed to every runtime mapper call (default value: -1).
static constexpr int64_t defaultDevice{-1};
46 
47 /// Create a constant string location from the MLIR Location information.
48 static llvm::Constant *createSourceLocStrFromLocation(Location loc,
49                                                       OpenACCIRBuilder &builder,
50                                                       StringRef name) {
51   if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
52     StringRef fileName = fileLoc.getFilename();
53     unsigned lineNo = fileLoc.getLine();
54     unsigned colNo = fileLoc.getColumn();
55     return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo);
56   } else {
57     std::string locStr;
58     llvm::raw_string_ostream locOS(locStr);
59     locOS << loc;
60     return builder.getOrCreateSrcLocStr(locOS.str());
61   }
62 }
63 
64 /// Create the location struct from the operation location information.
65 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder,
66                                              Operation *op) {
67   auto loc = op->getLoc();
68   auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
69   StringRef funcName = funcOp ? funcOp.getName() : "unknown";
70   llvm::Constant *locStr =
71       createSourceLocStrFromLocation(loc, builder, funcName);
72   return builder.getOrCreateIdent(locStr);
73 }
74 
75 /// Create a constant string representing the mapping information extracted from
76 /// the MLIR location information.
77 static llvm::Constant *createMappingInformation(Location loc,
78                                                 OpenACCIRBuilder &builder) {
79   if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
80     StringRef name = nameLoc.getName();
81     return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name);
82   } else {
83     return createSourceLocStrFromLocation(loc, builder, "unknown");
84   }
85 }
86 
87 /// Return the runtime function used to lower the given operation.
88 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
89                                              Operation *op) {
90   return llvm::TypeSwitch<Operation *, llvm::Function *>(op)
91       .Case([&](acc::EnterDataOp) {
92         return builder.getOrCreateRuntimeFunctionPtr(
93             llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
94       })
95       .Case([&](acc::ExitDataOp) {
96         return builder.getOrCreateRuntimeFunctionPtr(
97             llvm::omp::OMPRTL___tgt_target_data_end_mapper);
98       });
99   llvm_unreachable("Unknown OpenACC operation");
100 }
101 
102 /// Computes the size of type in bytes.
103 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder,
104                                    llvm::Value *basePtr) {
105   llvm::LLVMContext &ctx = builder.getContext();
106   llvm::Value *null =
107       llvm::Constant::getNullValue(basePtr->getType()->getPointerTo());
108   llvm::Value *sizeGep =
109       builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1));
110   llvm::Value *sizePtrToInt =
111       builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx));
112   return sizePtrToInt;
113 }
114 
115 /// Extract pointer, size and mapping information from operands
116 /// to populate the future functions arguments.
117 static LogicalResult
118 processOperands(llvm::IRBuilderBase &builder,
119                 LLVM::ModuleTranslation &moduleTranslation, Operation *op,
120                 ValueRange operands, unsigned totalNbOperand,
121                 uint64_t operandFlag, SmallVector<uint64_t> &flags,
122                 SmallVector<llvm::Constant *> &names, unsigned &index,
123                 llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
124                 llvm::AllocaInst *argSizes) {
125   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
126   llvm::LLVMContext &ctx = builder.getContext();
127   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
128   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
129   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
130   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
131 
132   for (Value data : operands) {
133     llvm::Value *dataValue = moduleTranslation.lookupValue(data);
134 
135     llvm::Value *dataPtrBase;
136     llvm::Value *dataPtr;
137     llvm::Value *dataSize;
138 
139     // Handle operands that were converted to DataDescriptor.
140     if (DataDescriptor::isValid(data)) {
141       dataPtrBase =
142           builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor);
143       dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor);
144       dataSize =
145           builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor);
146     } else if (data.getType().isa<LLVM::LLVMPointerType>()) {
147       dataPtrBase = dataValue;
148       dataPtr = dataValue;
149       dataSize = getSizeInBytes(builder, dataValue);
150     } else {
151       return op->emitOpError()
152              << "Data operand must be legalized before translation."
153              << "Unsupported type: " << data.getType();
154     }
155 
156     // Store base pointer extracted from operand into the i-th position of
157     // argBase.
158     llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
159         arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(index)});
160     llvm::Value *ptrBaseCast = builder.CreateBitCast(
161         ptrBaseGEP, dataPtrBase->getType()->getPointerTo());
162     builder.CreateStore(dataPtrBase, ptrBaseCast);
163 
164     // Store pointer extracted from operand into the i-th position of args.
165     llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
166         arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(index)});
167     llvm::Value *ptrCast =
168         builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo());
169     builder.CreateStore(dataPtr, ptrCast);
170 
171     // Store size extracted from operand into the i-th position of argSizes.
172     llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
173         arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)});
174     builder.CreateStore(dataSize, sizeGEP);
175 
176     flags.push_back(operandFlag);
177     llvm::Constant *mapName =
178         createMappingInformation(data.getLoc(), *accBuilder);
179     names.push_back(mapName);
180     ++index;
181   }
182   return success();
183 }
184 
185 /// Process data operands from acc::EnterDataOp
186 static LogicalResult
187 processDataOperands(llvm::IRBuilderBase &builder,
188                     LLVM::ModuleTranslation &moduleTranslation,
189                     acc::EnterDataOp op, SmallVector<uint64_t> &flags,
190                     SmallVector<llvm::Constant *> &names, unsigned &index,
191                     llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
192                     llvm::AllocaInst *argSizes) {
193   // TODO add `create_zero` and `attach` operands
194 
195   // Create operands are handled as `alloc` call.
196   if (failed(processOperands(builder, moduleTranslation, op,
197                              op.createOperands(), op.getNumDataOperands(),
198                              createFlag, flags, names, index, argsBase, args,
199                              argSizes)))
200     return failure();
201 
202   // Copyin operands are handled as `to` call.
203   if (failed(processOperands(builder, moduleTranslation, op,
204                              op.copyinOperands(), op.getNumDataOperands(),
205                              copyinFlag, flags, names, index, argsBase, args,
206                              argSizes)))
207     return failure();
208 
209   return success();
210 }
211 
212 /// Process data operands from acc::ExitDataOp
213 static LogicalResult
214 processDataOperands(llvm::IRBuilderBase &builder,
215                     LLVM::ModuleTranslation &moduleTranslation,
216                     acc::ExitDataOp op, SmallVector<uint64_t> &flags,
217                     SmallVector<llvm::Constant *> &names, unsigned &index,
218                     llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
219                     llvm::AllocaInst *argSizes) {
220   // TODO add `detach` operands
221 
222   // Delete operands are handled as `delete` call.
223   if (failed(processOperands(builder, moduleTranslation, op,
224                              op.deleteOperands(), op.getNumDataOperands(),
225                              deleteFlag, flags, names, index, argsBase, args,
226                              argSizes)))
227     return failure();
228 
229   // Copyout operands are handled as `from` call.
230   if (failed(processOperands(builder, moduleTranslation, op,
231                              op.copyoutOperands(), op.getNumDataOperands(),
232                              copyoutFlag, flags, names, index, argsBase, args,
233                              argSizes)))
234     return failure();
235 
236   return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // Conversion functions
241 //===----------------------------------------------------------------------===//
242 
243 /// Converts an OpenACC standalone data operation into LLVM IR.
244 template <typename OpTy>
245 static LogicalResult
246 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
247                         LLVM::ModuleTranslation &moduleTranslation) {
248   auto enclosingFuncOp =
249       op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>();
250   llvm::Function *enclosingFunction =
251       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
252 
253   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
254 
255   auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
256   auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
257 
258   // Number of arguments in the enter_data operation.
259   unsigned totalNbOperand = op.getNumDataOperands();
260 
261   // TODO could be moved to OpenXXIRBuilder?
262   llvm::LLVMContext &ctx = builder.getContext();
263   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
264   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
265   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
266   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
267   llvm::IRBuilder<>::InsertPoint allocaIP(
268       &enclosingFunction->getEntryBlock(),
269       enclosingFunction->getEntryBlock().getFirstInsertionPt());
270   llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP();
271   builder.restoreIP(allocaIP);
272   llvm::AllocaInst *argsBase = builder.CreateAlloca(arrI8PtrTy);
273   llvm::AllocaInst *args = builder.CreateAlloca(arrI8PtrTy);
274   llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty);
275   builder.restoreIP(currentIP);
276 
277   SmallVector<uint64_t> flags;
278   SmallVector<llvm::Constant *> names;
279   unsigned index = 0;
280 
281   if (failed(processDataOperands(builder, moduleTranslation, op, flags, names,
282                                  index, argsBase, args, argSizes)))
283     return failure();
284 
285   llvm::GlobalVariable *maptypes =
286       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
287   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
288       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
289       maptypes, /*Idx0=*/0, /*Idx1=*/0);
290 
291   llvm::GlobalVariable *mapnames =
292       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
293   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
294       llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
295       mapnames, /*Idx0=*/0, /*Idx1=*/0);
296 
297   llvm::Value *argsBaseGEP = builder.CreateInBoundsGEP(
298       arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(0)});
299   llvm::Value *argsGEP = builder.CreateInBoundsGEP(
300       arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(0)});
301   llvm::Value *argSizesGEP = builder.CreateInBoundsGEP(
302       arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)});
303   llvm::Value *nullPtr = llvm::Constant::getNullValue(
304       llvm::Type::getInt8PtrTy(ctx)->getPointerTo());
305 
306   builder.CreateCall(mapperFunc,
307                      {srcLocInfo, builder.getInt64(defaultDevice),
308                       builder.getInt32(totalNbOperand), argsBaseGEP, argsGEP,
309                       argSizesGEP, maptypesArg, mapnamesArg, nullPtr});
310 
311   return success();
312 }
313 
314 namespace {
315 
316 /// Implementation of the dialect interface that converts operations belonging
317 /// to the OpenACC dialect to LLVM IR.
318 class OpenACCDialectLLVMIRTranslationInterface
319     : public LLVMTranslationDialectInterface {
320 public:
321   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
322 
323   /// Translates the given operation to LLVM IR using the provided IR builder
324   /// and saving the state in `moduleTranslation`.
325   LogicalResult
326   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
327                    LLVM::ModuleTranslation &moduleTranslation) const final;
328 };
329 
330 } // end namespace
331 
332 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
333 /// (including OpenACC runtime calls).
334 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
335     Operation *op, llvm::IRBuilderBase &builder,
336     LLVM::ModuleTranslation &moduleTranslation) const {
337 
338   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
339       .Case([&](acc::EnterDataOp enterDataOp) {
340         return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder,
341                                                          moduleTranslation);
342       })
343       .Case([&](acc::ExitDataOp exitDataOp) {
344         return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder,
345                                                         moduleTranslation);
346       })
347       .Default([&](Operation *op) {
348         return op->emitError("unsupported OpenACC operation: ")
349                << op->getName();
350       });
351 }
352 
353 void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
354   registry.insert<acc::OpenACCDialect>();
355   registry.addDialectInterface<acc::OpenACCDialect,
356                                OpenACCDialectLLVMIRTranslationInterface>();
357 }
358 
359 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
360   DialectRegistry registry;
361   registerOpenACCDialectTranslation(registry);
362   context.appendDialectRegistry(registry);
363 }
364