1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
14 #include "mlir/Conversion/LLVMCommon/Pattern.h"
15 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
16 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 /// A pattern that converts the region arguments in a single-region OpenMP
25 /// operation to the LLVM dialect. The body of the region is not modified and is
26 /// expected to either be processed by the conversion infrastructure or already
27 /// contain ops compatible with LLVM dialect types.
28 template <typename OpType>
29 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
30   using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
31 
32   LogicalResult
33   matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
34                   ConversionPatternRewriter &rewriter) const override {
35     auto newOp = rewriter.create<OpType>(
36         curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
37     rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
38                                 newOp.region().end());
39     if (failed(rewriter.convertRegionTypes(&newOp.region(),
40                                            *this->getTypeConverter())))
41       return failure();
42 
43     rewriter.eraseOp(curOp);
44     return success();
45   }
46 };
47 } // namespace
48 
49 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
50                                                   RewritePatternSet &patterns) {
51   patterns.add<RegionOpConversion<omp::MasterOp>,
52                RegionOpConversion<omp::ParallelOp>,
53                RegionOpConversion<omp::WsLoopOp>>(converter);
54 }
55 
56 namespace {
57 struct ConvertOpenMPToLLVMPass
58     : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
59   void runOnOperation() override;
60 };
61 } // namespace
62 
63 void ConvertOpenMPToLLVMPass::runOnOperation() {
64   auto module = getOperation();
65 
66   // Convert to OpenMP operations with LLVM IR dialect
67   RewritePatternSet patterns(&getContext());
68   LLVMTypeConverter converter(&getContext());
69   mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
70   populateMemRefToLLVMConversionPatterns(converter, patterns);
71   populateStdToLLVMConversionPatterns(converter, patterns);
72   populateOpenMPToLLVMConversionPatterns(converter, patterns);
73 
74   LLVMConversionTarget target(getContext());
75   target.addDynamicallyLegalOp<omp::MasterOp, omp::ParallelOp, omp::WsLoopOp>(
76       [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
77   target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
78                     omp::BarrierOp, omp::TaskwaitOp>();
79   if (failed(applyPartialConversion(module, target, std::move(patterns))))
80     signalPassFailure();
81 }
82 
83 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
84   return std::make_unique<ConvertOpenMPToLLVMPass>();
85 }
86