1 //===- OpenACCToLLVMIRTranslation.cpp -------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a translation between the MLIR OpenACC dialect and LLVM
10 // IR.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.h"
15 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/OpenACC/OpenACC.h"
18 #include "mlir/IR/BuiltinOps.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/Support/LLVM.h"
21 #include "mlir/Target/LLVMIR/ModuleTranslation.h"
22 
23 #include "llvm/ADT/TypeSwitch.h"
24 #include "llvm/Frontend/OpenMP/OMPConstants.h"
25 #include "llvm/Support/FormatVariadic.h"
26 
27 using namespace mlir;
28 
29 using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
30 
31 //===----------------------------------------------------------------------===//
32 // Utility functions
33 //===----------------------------------------------------------------------===//
34 
/// Map-type flag values. They are extracted from
/// openmp/libomptarget/include/omptarget.h and mapped to the corresponding
/// OpenACC flags.

/// Flag for `create` operands (handled as an `alloc` mapping).
static constexpr uint64_t kCreateFlag = 0x000;
/// Flag for `copyin`/`device` operands (handled as a `to` mapping).
static constexpr uint64_t kDeviceCopyinFlag = 0x001;
/// Flag for `copyout`/`host` operands (handled as a `from` mapping).
static constexpr uint64_t kHostCopyoutFlag = 0x002;
/// Flag for `copy` operands: the union of the copyin and copyout flags.
static constexpr uint64_t kCopyFlag = kDeviceCopyinFlag | kHostCopyoutFlag;
/// Flag for `present` operands.
static constexpr uint64_t kPresentFlag = 0x1000;
/// Flag for `delete` operands (handled as a `delete` mapping).
static constexpr uint64_t kDeleteFlag = 0x008;
/// Runtime extension to implement the OpenACC second reference counter.
static constexpr uint64_t kHoldFlag = 0x2000;

/// Default value for the device id.
static constexpr int64_t kDefaultDevice = -1;
48 
49 /// Create a constant string location from the MLIR Location information.
createSourceLocStrFromLocation(Location loc,OpenACCIRBuilder & builder,StringRef name,uint32_t & strLen)50 static llvm::Constant *createSourceLocStrFromLocation(Location loc,
51                                                       OpenACCIRBuilder &builder,
52                                                       StringRef name,
53                                                       uint32_t &strLen) {
54   if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
55     StringRef fileName = fileLoc.getFilename();
56     unsigned lineNo = fileLoc.getLine();
57     unsigned colNo = fileLoc.getColumn();
58     return builder.getOrCreateSrcLocStr(name, fileName, lineNo, colNo, strLen);
59   }
60   std::string locStr;
61   llvm::raw_string_ostream locOS(locStr);
62   locOS << loc;
63   return builder.getOrCreateSrcLocStr(locOS.str(), strLen);
64 }
65 
66 /// Create the location struct from the operation location information.
createSourceLocationInfo(OpenACCIRBuilder & builder,Operation * op)67 static llvm::Value *createSourceLocationInfo(OpenACCIRBuilder &builder,
68                                              Operation *op) {
69   auto loc = op->getLoc();
70   auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
71   StringRef funcName = funcOp ? funcOp.getName() : "unknown";
72   uint32_t strLen;
73   llvm::Constant *locStr =
74       createSourceLocStrFromLocation(loc, builder, funcName, strLen);
75   return builder.getOrCreateIdent(locStr, strLen);
76 }
77 
78 /// Create a constant string representing the mapping information extracted from
79 /// the MLIR location information.
createMappingInformation(Location loc,OpenACCIRBuilder & builder)80 static llvm::Constant *createMappingInformation(Location loc,
81                                                 OpenACCIRBuilder &builder) {
82   uint32_t strLen;
83   if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
84     StringRef name = nameLoc.getName();
85     return createSourceLocStrFromLocation(nameLoc.getChildLoc(), builder, name,
86                                           strLen);
87   }
88   return createSourceLocStrFromLocation(loc, builder, "unknown", strLen);
89 }
90 
91 /// Return the runtime function used to lower the given operation.
getAssociatedFunction(OpenACCIRBuilder & builder,Operation * op)92 static llvm::Function *getAssociatedFunction(OpenACCIRBuilder &builder,
93                                              Operation *op) {
94   return llvm::TypeSwitch<Operation *, llvm::Function *>(op)
95       .Case([&](acc::EnterDataOp) {
96         return builder.getOrCreateRuntimeFunctionPtr(
97             llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
98       })
99       .Case([&](acc::ExitDataOp) {
100         return builder.getOrCreateRuntimeFunctionPtr(
101             llvm::omp::OMPRTL___tgt_target_data_end_mapper);
102       })
103       .Case([&](acc::UpdateOp) {
104         return builder.getOrCreateRuntimeFunctionPtr(
105             llvm::omp::OMPRTL___tgt_target_data_update_mapper);
106       });
107   llvm_unreachable("Unknown OpenACC operation");
108 }
109 
110 /// Computes the size of type in bytes.
getSizeInBytes(llvm::IRBuilderBase & builder,llvm::Value * basePtr)111 static llvm::Value *getSizeInBytes(llvm::IRBuilderBase &builder,
112                                    llvm::Value *basePtr) {
113   llvm::LLVMContext &ctx = builder.getContext();
114   llvm::Value *null =
115       llvm::Constant::getNullValue(basePtr->getType()->getPointerTo());
116   llvm::Value *sizeGep =
117       builder.CreateGEP(basePtr->getType(), null, builder.getInt32(1));
118   llvm::Value *sizePtrToInt =
119       builder.CreatePtrToInt(sizeGep, llvm::Type::getInt64Ty(ctx));
120   return sizePtrToInt;
121 }
122 
123 /// Extract pointer, size and mapping information from operands
124 /// to populate the future functions arguments.
125 static LogicalResult
processOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,Operation * op,ValueRange operands,unsigned totalNbOperand,uint64_t operandFlag,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,unsigned & index,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)126 processOperands(llvm::IRBuilderBase &builder,
127                 LLVM::ModuleTranslation &moduleTranslation, Operation *op,
128                 ValueRange operands, unsigned totalNbOperand,
129                 uint64_t operandFlag, SmallVector<uint64_t> &flags,
130                 SmallVectorImpl<llvm::Constant *> &names, unsigned &index,
131                 struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
132   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
133   llvm::LLVMContext &ctx = builder.getContext();
134   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
135   auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
136   auto *i64Ty = llvm::Type::getInt64Ty(ctx);
137   auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
138 
139   for (Value data : operands) {
140     llvm::Value *dataValue = moduleTranslation.lookupValue(data);
141 
142     llvm::Value *dataPtrBase;
143     llvm::Value *dataPtr;
144     llvm::Value *dataSize;
145 
146     // Handle operands that were converted to DataDescriptor.
147     if (DataDescriptor::isValid(data)) {
148       dataPtrBase =
149           builder.CreateExtractValue(dataValue, kPtrBasePosInDataDescriptor);
150       dataPtr = builder.CreateExtractValue(dataValue, kPtrPosInDataDescriptor);
151       dataSize =
152           builder.CreateExtractValue(dataValue, kSizePosInDataDescriptor);
153     } else if (data.getType().isa<LLVM::LLVMPointerType>()) {
154       dataPtrBase = dataValue;
155       dataPtr = dataValue;
156       dataSize = getSizeInBytes(builder, dataValue);
157     } else {
158       return op->emitOpError()
159              << "Data operand must be legalized before translation."
160              << "Unsupported type: " << data.getType();
161     }
162 
163     // Store base pointer extracted from operand into the i-th position of
164     // argBase.
165     llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
166         arrI8PtrTy, mapperAllocas.ArgsBase,
167         {builder.getInt32(0), builder.getInt32(index)});
168     llvm::Value *ptrBaseCast = builder.CreateBitCast(
169         ptrBaseGEP, dataPtrBase->getType()->getPointerTo());
170     builder.CreateStore(dataPtrBase, ptrBaseCast);
171 
172     // Store pointer extracted from operand into the i-th position of args.
173     llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
174         arrI8PtrTy, mapperAllocas.Args,
175         {builder.getInt32(0), builder.getInt32(index)});
176     llvm::Value *ptrCast =
177         builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo());
178     builder.CreateStore(dataPtr, ptrCast);
179 
180     // Store size extracted from operand into the i-th position of argSizes.
181     llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
182         arrI64Ty, mapperAllocas.ArgSizes,
183         {builder.getInt32(0), builder.getInt32(index)});
184     builder.CreateStore(dataSize, sizeGEP);
185 
186     flags.push_back(operandFlag);
187     llvm::Constant *mapName =
188         createMappingInformation(data.getLoc(), *accBuilder);
189     names.push_back(mapName);
190     ++index;
191   }
192   return success();
193 }
194 
195 /// Process data operands from acc::EnterDataOp
196 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::EnterDataOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)197 processDataOperands(llvm::IRBuilderBase &builder,
198                     LLVM::ModuleTranslation &moduleTranslation,
199                     acc::EnterDataOp op, SmallVector<uint64_t> &flags,
200                     SmallVectorImpl<llvm::Constant *> &names,
201                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
202   // TODO add `create_zero` and `attach` operands
203 
204   unsigned index = 0;
205 
206   // Create operands are handled as `alloc` call.
207   if (failed(processOperands(builder, moduleTranslation, op,
208                              op.createOperands(), op.getNumDataOperands(),
209                              kCreateFlag, flags, names, index, mapperAllocas)))
210     return failure();
211 
212   // Copyin operands are handled as `to` call.
213   if (failed(processOperands(builder, moduleTranslation, op,
214                              op.copyinOperands(), op.getNumDataOperands(),
215                              kDeviceCopyinFlag, flags, names, index,
216                              mapperAllocas)))
217     return failure();
218 
219   return success();
220 }
221 
222 /// Process data operands from acc::ExitDataOp
223 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::ExitDataOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)224 processDataOperands(llvm::IRBuilderBase &builder,
225                     LLVM::ModuleTranslation &moduleTranslation,
226                     acc::ExitDataOp op, SmallVector<uint64_t> &flags,
227                     SmallVectorImpl<llvm::Constant *> &names,
228                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
229   // TODO add `detach` operands
230 
231   unsigned index = 0;
232 
233   // Delete operands are handled as `delete` call.
234   if (failed(processOperands(builder, moduleTranslation, op,
235                              op.deleteOperands(), op.getNumDataOperands(),
236                              kDeleteFlag, flags, names, index, mapperAllocas)))
237     return failure();
238 
239   // Copyout operands are handled as `from` call.
240   if (failed(processOperands(builder, moduleTranslation, op,
241                              op.copyoutOperands(), op.getNumDataOperands(),
242                              kHostCopyoutFlag, flags, names, index,
243                              mapperAllocas)))
244     return failure();
245 
246   return success();
247 }
248 
249 /// Process data operands from acc::UpdateOp
250 static LogicalResult
processDataOperands(llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation,acc::UpdateOp op,SmallVector<uint64_t> & flags,SmallVectorImpl<llvm::Constant * > & names,struct OpenACCIRBuilder::MapperAllocas & mapperAllocas)251 processDataOperands(llvm::IRBuilderBase &builder,
252                     LLVM::ModuleTranslation &moduleTranslation,
253                     acc::UpdateOp op, SmallVector<uint64_t> &flags,
254                     SmallVectorImpl<llvm::Constant *> &names,
255                     struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
256   unsigned index = 0;
257 
258   // Host operands are handled as `from` call.
259   if (failed(processOperands(builder, moduleTranslation, op, op.hostOperands(),
260                              op.getNumDataOperands(), kHostCopyoutFlag, flags,
261                              names, index, mapperAllocas)))
262     return failure();
263 
264   // Device operands are handled as `to` call.
265   if (failed(processOperands(builder, moduleTranslation, op,
266                              op.deviceOperands(), op.getNumDataOperands(),
267                              kDeviceCopyinFlag, flags, names, index,
268                              mapperAllocas)))
269     return failure();
270 
271   return success();
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // Conversion functions
276 //===----------------------------------------------------------------------===//
277 
278 /// Converts an OpenACC data operation into LLVM IR.
convertDataOp(acc::DataOp & op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation)279 static LogicalResult convertDataOp(acc::DataOp &op,
280                                    llvm::IRBuilderBase &builder,
281                                    LLVM::ModuleTranslation &moduleTranslation) {
282   llvm::LLVMContext &ctx = builder.getContext();
283   auto enclosingFuncOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
284   llvm::Function *enclosingFunction =
285       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
286 
287   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
288 
289   llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
290 
291   llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
292       llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
293 
294   llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
295       llvm::omp::OMPRTL___tgt_target_data_end_mapper);
296 
297   // Number of arguments in the data operation.
298   unsigned totalNbOperand = op.getNumDataOperands();
299 
300   struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
301   OpenACCIRBuilder::InsertPointTy allocaIP(
302       &enclosingFunction->getEntryBlock(),
303       enclosingFunction->getEntryBlock().getFirstInsertionPt());
304   accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
305                                   mapperAllocas);
306 
307   SmallVector<uint64_t> flags;
308   SmallVector<llvm::Constant *> names;
309   unsigned index = 0;
310 
311   // TODO handle no_create, deviceptr and attach operands.
312 
313   if (failed(processOperands(builder, moduleTranslation, op, op.copyOperands(),
314                              totalNbOperand, kCopyFlag | kHoldFlag, flags,
315                              names, index, mapperAllocas)))
316     return failure();
317 
318   if (failed(processOperands(
319           builder, moduleTranslation, op, op.copyinOperands(), totalNbOperand,
320           kDeviceCopyinFlag | kHoldFlag, flags, names, index, mapperAllocas)))
321     return failure();
322 
323   // TODO copyin readonly currenlty handled as copyin. Update when extension
324   // available.
325   if (failed(processOperands(builder, moduleTranslation, op,
326                              op.copyinReadonlyOperands(), totalNbOperand,
327                              kDeviceCopyinFlag | kHoldFlag, flags, names, index,
328                              mapperAllocas)))
329     return failure();
330 
331   if (failed(processOperands(
332           builder, moduleTranslation, op, op.copyoutOperands(), totalNbOperand,
333           kHostCopyoutFlag | kHoldFlag, flags, names, index, mapperAllocas)))
334     return failure();
335 
336   // TODO copyout zero currenlty handled as copyout. Update when extension
337   // available.
338   if (failed(processOperands(builder, moduleTranslation, op,
339                              op.copyoutZeroOperands(), totalNbOperand,
340                              kHostCopyoutFlag | kHoldFlag, flags, names, index,
341                              mapperAllocas)))
342     return failure();
343 
344   if (failed(processOperands(
345           builder, moduleTranslation, op, op.createOperands(), totalNbOperand,
346           kCreateFlag | kHoldFlag, flags, names, index, mapperAllocas)))
347     return failure();
348 
349   // TODO create zero currenlty handled as create. Update when extension
350   // available.
351   if (failed(processOperands(builder, moduleTranslation, op,
352                              op.createZeroOperands(), totalNbOperand,
353                              kCreateFlag | kHoldFlag, flags, names, index,
354                              mapperAllocas)))
355     return failure();
356 
357   if (failed(processOperands(
358           builder, moduleTranslation, op, op.presentOperands(), totalNbOperand,
359           kPresentFlag | kHoldFlag, flags, names, index, mapperAllocas)))
360     return failure();
361 
362   llvm::GlobalVariable *maptypes =
363       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
364   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
365       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
366       maptypes, /*Idx0=*/0, /*Idx1=*/0);
367 
368   llvm::GlobalVariable *mapnames =
369       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
370   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
371       llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
372       mapnames, /*Idx0=*/0, /*Idx1=*/0);
373 
374   // Create call to start the data region.
375   accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo,
376                              maptypesArg, mapnamesArg, mapperAllocas,
377                              kDefaultDevice, totalNbOperand);
378 
379   // Convert the region.
380   llvm::BasicBlock *entryBlock = nullptr;
381 
382   for (Block &bb : op.region()) {
383     llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
384         ctx, "acc.data", builder.GetInsertBlock()->getParent());
385     if (entryBlock == nullptr)
386       entryBlock = llvmBB;
387     moduleTranslation.mapBlock(&bb, llvmBB);
388   }
389 
390   auto afterDataRegion = builder.saveIP();
391 
392   llvm::BranchInst *sourceTerminator = builder.CreateBr(entryBlock);
393 
394   builder.restoreIP(afterDataRegion);
395   llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
396       ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
397 
398   SetVector<Block *> blocks =
399       LLVM::detail::getTopologicallySortedBlocks(op.region());
400   for (Block *bb : blocks) {
401     llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
402     if (bb->isEntryBlock()) {
403       assert(sourceTerminator->getNumSuccessors() == 1 &&
404              "provided entry block has multiple successors");
405       sourceTerminator->setSuccessor(0, llvmBB);
406     }
407 
408     if (failed(
409             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
410       return failure();
411     }
412 
413     if (isa<acc::TerminatorOp, acc::YieldOp>(bb->getTerminator()))
414       builder.CreateBr(endDataBlock);
415   }
416 
417   // Create call to end the data region.
418   builder.SetInsertPoint(endDataBlock);
419   accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo,
420                              maptypesArg, mapnamesArg, mapperAllocas,
421                              kDefaultDevice, totalNbOperand);
422 
423   return success();
424 }
425 
426 /// Converts an OpenACC standalone data operation into LLVM IR.
427 template <typename OpTy>
428 static LogicalResult
convertStandaloneDataOp(OpTy & op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation)429 convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
430                         LLVM::ModuleTranslation &moduleTranslation) {
431   auto enclosingFuncOp =
432       op.getOperation()->template getParentOfType<LLVM::LLVMFuncOp>();
433   llvm::Function *enclosingFunction =
434       moduleTranslation.lookupFunction(enclosingFuncOp.getName());
435 
436   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
437 
438   auto *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
439   auto *mapperFunc = getAssociatedFunction(*accBuilder, op);
440 
441   // Number of arguments in the enter_data operation.
442   unsigned totalNbOperand = op.getNumDataOperands();
443 
444   llvm::LLVMContext &ctx = builder.getContext();
445 
446   struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
447   OpenACCIRBuilder::InsertPointTy allocaIP(
448       &enclosingFunction->getEntryBlock(),
449       enclosingFunction->getEntryBlock().getFirstInsertionPt());
450   accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
451                                   mapperAllocas);
452 
453   SmallVector<uint64_t> flags;
454   SmallVector<llvm::Constant *> names;
455 
456   if (failed(processDataOperands(builder, moduleTranslation, op, flags, names,
457                                  mapperAllocas)))
458     return failure();
459 
460   llvm::GlobalVariable *maptypes =
461       accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
462   llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
463       llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
464       maptypes, /*Idx0=*/0, /*Idx1=*/0);
465 
466   llvm::GlobalVariable *mapnames =
467       accBuilder->createOffloadMapnames(names, ".offload_mapnames");
468   llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
469       llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
470       mapnames, /*Idx0=*/0, /*Idx1=*/0);
471 
472   accBuilder->emitMapperCall(builder.saveIP(), mapperFunc, srcLocInfo,
473                              maptypesArg, mapnamesArg, mapperAllocas,
474                              kDefaultDevice, totalNbOperand);
475 
476   return success();
477 }
478 
479 namespace {
480 
481 /// Implementation of the dialect interface that converts operations belonging
482 /// to the OpenACC dialect to LLVM IR.
483 class OpenACCDialectLLVMIRTranslationInterface
484     : public LLVMTranslationDialectInterface {
485 public:
486   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
487 
488   /// Translates the given operation to LLVM IR using the provided IR builder
489   /// and saving the state in `moduleTranslation`.
490   LogicalResult
491   convertOperation(Operation *op, llvm::IRBuilderBase &builder,
492                    LLVM::ModuleTranslation &moduleTranslation) const final;
493 };
494 
495 } // namespace
496 
497 /// Given an OpenACC MLIR operation, create the corresponding LLVM IR
498 /// (including OpenACC runtime calls).
convertOperation(Operation * op,llvm::IRBuilderBase & builder,LLVM::ModuleTranslation & moduleTranslation) const499 LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
500     Operation *op, llvm::IRBuilderBase &builder,
501     LLVM::ModuleTranslation &moduleTranslation) const {
502 
503   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
504       .Case([&](acc::DataOp dataOp) {
505         return convertDataOp(dataOp, builder, moduleTranslation);
506       })
507       .Case([&](acc::EnterDataOp enterDataOp) {
508         return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder,
509                                                          moduleTranslation);
510       })
511       .Case([&](acc::ExitDataOp exitDataOp) {
512         return convertStandaloneDataOp<acc::ExitDataOp>(exitDataOp, builder,
513                                                         moduleTranslation);
514       })
515       .Case([&](acc::UpdateOp updateOp) {
516         return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder,
517                                                       moduleTranslation);
518       })
519       .Case<acc::TerminatorOp, acc::YieldOp>([](auto op) {
520         // `yield` and `terminator` can be just omitted. The block structure was
521         // created in the function that handles their parent operation.
522         assert(op->getNumOperands() == 0 &&
523                "unexpected OpenACC terminator with operands");
524         return success();
525       })
526       .Default([&](Operation *op) {
527         return op->emitError("unsupported OpenACC operation: ")
528                << op->getName();
529       });
530 }
531 
registerOpenACCDialectTranslation(DialectRegistry & registry)532 void mlir::registerOpenACCDialectTranslation(DialectRegistry &registry) {
533   registry.insert<acc::OpenACCDialect>();
534   registry.addExtension(+[](MLIRContext *ctx, acc::OpenACCDialect *dialect) {
535     dialect->addInterfaces<OpenACCDialectLLVMIRTranslationInterface>();
536   });
537 }
538 
registerOpenACCDialectTranslation(MLIRContext & context)539 void mlir::registerOpenACCDialectTranslation(MLIRContext &context) {
540   DialectRegistry registry;
541   registerOpenACCDialectTranslation(registry);
542   context.appendDialectRegistry(registry);
543 }
544