1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10 
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
13 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 /// A pattern that converts the region arguments in a single-region OpenMP
26 /// operation to the LLVM dialect. The body of the region is not modified and is
27 /// expected to either be processed by the conversion infrastructure or already
28 /// contain ops compatible with LLVM dialect types.
29 template <typename OpType>
30 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
31   using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
32 
33   LogicalResult
34   matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
35                   ConversionPatternRewriter &rewriter) const override {
36     auto newOp = rewriter.create<OpType>(
37         curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
38     rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
39                                 newOp.region().end());
40     if (failed(rewriter.convertRegionTypes(&newOp.region(),
41                                            *this->getTypeConverter())))
42       return failure();
43 
44     rewriter.eraseOp(curOp);
45     return success();
46   }
47 };
48 
49 template <typename T>
50 struct RegionLessOpWithVarOperandsConversion
51     : public ConvertOpToLLVMPattern<T> {
52   using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
53   LogicalResult
54   matchAndRewrite(T curOp, typename T::Adaptor adaptor,
55                   ConversionPatternRewriter &rewriter) const override {
56     TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
57     SmallVector<Type> resTypes;
58     if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
59       return failure();
60     SmallVector<Value> convertedOperands;
61     assert(curOp.getNumVariableOperands() ==
62                curOp.getOperation()->getNumOperands() &&
63            "unexpected non-variable operands");
64     for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
65       Value originalVariableOperand = curOp.getVariableOperand(idx);
66       if (!originalVariableOperand)
67         return failure();
68       if (originalVariableOperand.getType().isa<MemRefType>()) {
69         // TODO: Support memref type in variable operands
70         return rewriter.notifyMatchFailure(curOp,
71                                            "memref is not supported yet");
72       }
73       convertedOperands.emplace_back(adaptor.getOperands()[idx]);
74     }
75     rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
76                                    curOp->getAttrs());
77     return success();
78   }
79 };
80 } // namespace
81 
82 void mlir::configureOpenMPToLLVMConversionLegality(
83     ConversionTarget &target, LLVMTypeConverter &typeConverter) {
84   target.addDynamicallyLegalOp<mlir::omp::CriticalOp, mlir::omp::ParallelOp,
85                                mlir::omp::WsLoopOp, mlir::omp::MasterOp,
86                                mlir::omp::SectionsOp, mlir::omp::SingleOp>(
87       [&](Operation *op) {
88         return typeConverter.isLegal(&op->getRegion(0)) &&
89                typeConverter.isLegal(op->getOperandTypes()) &&
90                typeConverter.isLegal(op->getResultTypes());
91       });
92   target
93       .addDynamicallyLegalOp<mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp,
94                              mlir::omp::FlushOp, mlir::omp::ThreadprivateOp>(
95           [&](Operation *op) {
96             return typeConverter.isLegal(op->getOperandTypes()) &&
97                    typeConverter.isLegal(op->getResultTypes());
98           });
99 }
100 
101 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
102                                                   RewritePatternSet &patterns) {
103   patterns.add<
104       RegionOpConversion<omp::CriticalOp>, RegionOpConversion<omp::MasterOp>,
105       RegionOpConversion<omp::ParallelOp>, RegionOpConversion<omp::WsLoopOp>,
106       RegionOpConversion<omp::SectionsOp>, RegionOpConversion<omp::SingleOp>,
107       RegionLessOpWithVarOperandsConversion<omp::AtomicReadOp>,
108       RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
109       RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
110       RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>>(converter);
111 }
112 
113 namespace {
114 struct ConvertOpenMPToLLVMPass
115     : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
116   void runOnOperation() override;
117 };
118 } // namespace
119 
120 void ConvertOpenMPToLLVMPass::runOnOperation() {
121   auto module = getOperation();
122 
123   // Convert to OpenMP operations with LLVM IR dialect
124   RewritePatternSet patterns(&getContext());
125   LLVMTypeConverter converter(&getContext());
126   arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
127   cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
128   populateMemRefToLLVMConversionPatterns(converter, patterns);
129   populateFuncToLLVMConversionPatterns(converter, patterns);
130   populateOpenMPToLLVMConversionPatterns(converter, patterns);
131 
132   LLVMConversionTarget target(getContext());
133   target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
134                     omp::BarrierOp, omp::TaskwaitOp>();
135   configureOpenMPToLLVMConversionLegality(target, converter);
136   if (failed(applyPartialConversion(module, target, std::move(patterns))))
137     signalPassFailure();
138 }
139 
140 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
141   return std::make_unique<ConvertOpenMPToLLVMPass>();
142 }
143