1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/LLVMCommon/Pattern.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 /// A pattern that converts the region arguments in a single-region OpenMP
22 /// operation to the LLVM dialect. The body of the region is not modified and is
23 /// expected to either be processed by the conversion infrastructure or already
24 /// contain ops compatible with LLVM dialect types.
25 template <typename OpType>
26 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
27   using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
28 
29   LogicalResult
30   matchAndRewrite(OpType curOp, ArrayRef<Value> operands,
31                   ConversionPatternRewriter &rewriter) const override {
32     auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands,
33                                          curOp->getAttrs());
34     rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
35                                 newOp.region().end());
36     if (failed(rewriter.convertRegionTypes(&newOp.region(),
37                                            *this->getTypeConverter())))
38       return failure();
39 
40     rewriter.eraseOp(curOp);
41     return success();
42   }
43 };
44 } // namespace
45 
46 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
47                                                   RewritePatternSet &patterns) {
48   patterns.add<RegionOpConversion<omp::ParallelOp>,
49                RegionOpConversion<omp::WsLoopOp>>(converter);
50 }
51 
52 namespace {
53 struct ConvertOpenMPToLLVMPass
54     : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
55   void runOnOperation() override;
56 };
57 } // namespace
58 
59 void ConvertOpenMPToLLVMPass::runOnOperation() {
60   auto module = getOperation();
61 
62   // Convert to OpenMP operations with LLVM IR dialect
63   RewritePatternSet patterns(&getContext());
64   LLVMTypeConverter converter(&getContext());
65   populateStdToLLVMConversionPatterns(converter, patterns);
66   populateOpenMPToLLVMConversionPatterns(converter, patterns);
67 
68   LLVMConversionTarget target(getContext());
69   target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>(
70       [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); });
71   target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
72                     omp::BarrierOp, omp::TaskwaitOp>();
73   if (failed(applyPartialConversion(module, target, std::move(patterns))))
74     signalPassFailure();
75 }
76 
77 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
78   return std::make_unique<ConvertOpenMPToLLVMPass>();
79 }
80