1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 struct ParallelOpConversion : public ConvertToLLVMPattern {
20   explicit ParallelOpConversion(MLIRContext *context,
21                                 LLVMTypeConverter &typeConverter)
22       : ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context,
23                              typeConverter) {}
24 
25   LogicalResult
26   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
27                   ConversionPatternRewriter &rewriter) const override {
28     auto curOp = cast<omp::ParallelOp>(op);
29     auto newOp = rewriter.create<omp::ParallelOp>(
30         curOp.getLoc(), ArrayRef<Type>(), operands, curOp.getAttrs());
31     rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
32                                 newOp.region().end());
33     if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
34       return failure();
35 
36     rewriter.eraseOp(op);
37     return success();
38   }
39 };
40 } // namespace
41 
42 void mlir::populateOpenMPToLLVMConversionPatterns(
43     MLIRContext *context, LLVMTypeConverter &converter,
44     OwningRewritePatternList &patterns) {
45   patterns.insert<ParallelOpConversion>(context, converter);
46 }
47 
48 namespace {
49 struct ConvertOpenMPToLLVMPass
50     : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
51   void runOnOperation() override;
52 };
53 } // namespace
54 
55 void ConvertOpenMPToLLVMPass::runOnOperation() {
56   auto module = getOperation();
57   MLIRContext *context = &getContext();
58 
59   // Convert to OpenMP operations with LLVM IR dialect
60   OwningRewritePatternList patterns;
61   LLVMTypeConverter converter(&getContext());
62   populateStdToLLVMConversionPatterns(converter, patterns);
63   populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
64 
65   LLVMConversionTarget target(getContext());
66   target.addDynamicallyLegalOp<omp::ParallelOp>(
67       [&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); });
68   target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
69                     omp::BarrierOp, omp::TaskwaitOp>();
70   if (failed(applyPartialConversion(module, target, patterns)))
71     signalPassFailure();
72 }
73 
74 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
75   return std::make_unique<ConvertOpenMPToLLVMPass>();
76 }
77