1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10
11 #include "../PassDetail.h"
12 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
13 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17 #include "mlir/Conversion/LLVMCommon/Pattern.h"
18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21
22 using namespace mlir;
23
24 namespace {
25 /// A pattern that converts the region arguments in a single-region OpenMP
26 /// operation to the LLVM dialect. The body of the region is not modified and is
27 /// expected to either be processed by the conversion infrastructure or already
28 /// contain ops compatible with LLVM dialect types.
29 template <typename OpType>
30 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
31 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
32
33 LogicalResult
matchAndRewrite__anon494baa170111::RegionOpConversion34 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36 auto newOp = rewriter.create<OpType>(
37 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
38 rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
39 newOp.region().end());
40 if (failed(rewriter.convertRegionTypes(&newOp.region(),
41 *this->getTypeConverter())))
42 return failure();
43
44 rewriter.eraseOp(curOp);
45 return success();
46 }
47 };
48
49 template <typename T>
50 struct RegionLessOpWithVarOperandsConversion
51 : public ConvertOpToLLVMPattern<T> {
52 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
53 LogicalResult
matchAndRewrite__anon494baa170111::RegionLessOpWithVarOperandsConversion54 matchAndRewrite(T curOp, typename T::Adaptor adaptor,
55 ConversionPatternRewriter &rewriter) const override {
56 TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
57 SmallVector<Type> resTypes;
58 if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
59 return failure();
60 SmallVector<Value> convertedOperands;
61 assert(curOp.getNumVariableOperands() ==
62 curOp.getOperation()->getNumOperands() &&
63 "unexpected non-variable operands");
64 for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
65 Value originalVariableOperand = curOp.getVariableOperand(idx);
66 if (!originalVariableOperand)
67 return failure();
68 if (originalVariableOperand.getType().isa<MemRefType>()) {
69 // TODO: Support memref type in variable operands
70 return rewriter.notifyMatchFailure(curOp,
71 "memref is not supported yet");
72 }
73 convertedOperands.emplace_back(adaptor.getOperands()[idx]);
74 }
75 rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
76 curOp->getAttrs());
77 return success();
78 }
79 };
80
81 struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
82 using ConvertOpToLLVMPattern<omp::ReductionOp>::ConvertOpToLLVMPattern;
83 LogicalResult
matchAndRewrite__anon494baa170111::ReductionOpConversion84 matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor,
85 ConversionPatternRewriter &rewriter) const override {
86 if (curOp.accumulator().getType().isa<MemRefType>()) {
87 // TODO: Support memref type in variable operands
88 return rewriter.notifyMatchFailure(curOp, "memref is not supported yet");
89 }
90 rewriter.replaceOpWithNewOp<omp::ReductionOp>(
91 curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs());
92 return success();
93 }
94 };
95 } // namespace
96
configureOpenMPToLLVMConversionLegality(ConversionTarget & target,LLVMTypeConverter & typeConverter)97 void mlir::configureOpenMPToLLVMConversionLegality(
98 ConversionTarget &target, LLVMTypeConverter &typeConverter) {
99 target.addDynamicallyLegalOp<mlir::omp::CriticalOp, mlir::omp::ParallelOp,
100 mlir::omp::WsLoopOp, mlir::omp::MasterOp,
101 mlir::omp::SectionsOp, mlir::omp::SingleOp>(
102 [&](Operation *op) {
103 return typeConverter.isLegal(&op->getRegion(0)) &&
104 typeConverter.isLegal(op->getOperandTypes()) &&
105 typeConverter.isLegal(op->getResultTypes());
106 });
107 target
108 .addDynamicallyLegalOp<mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp,
109 mlir::omp::FlushOp, mlir::omp::ThreadprivateOp>(
110 [&](Operation *op) {
111 return typeConverter.isLegal(op->getOperandTypes()) &&
112 typeConverter.isLegal(op->getResultTypes());
113 });
114 target.addDynamicallyLegalOp<mlir::omp::ReductionOp>([&](Operation *op) {
115 return typeConverter.isLegal(op->getOperandTypes());
116 });
117 }
118
populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)119 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
120 RewritePatternSet &patterns) {
121 patterns.add<
122 ReductionOpConversion, RegionOpConversion<omp::CriticalOp>,
123 RegionOpConversion<omp::MasterOp>, ReductionOpConversion,
124 RegionOpConversion<omp::MasterOp>, RegionOpConversion<omp::ParallelOp>,
125 RegionOpConversion<omp::WsLoopOp>, RegionOpConversion<omp::SectionsOp>,
126 RegionOpConversion<omp::SingleOp>,
127 RegionLessOpWithVarOperandsConversion<omp::AtomicReadOp>,
128 RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
129 RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
130 RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>>(converter);
131 }
132
133 namespace {
134 struct ConvertOpenMPToLLVMPass
135 : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
136 void runOnOperation() override;
137 };
138 } // namespace
139
runOnOperation()140 void ConvertOpenMPToLLVMPass::runOnOperation() {
141 auto module = getOperation();
142
143 // Convert to OpenMP operations with LLVM IR dialect
144 RewritePatternSet patterns(&getContext());
145 LLVMTypeConverter converter(&getContext());
146 arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
147 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
148 populateMemRefToLLVMConversionPatterns(converter, patterns);
149 populateFuncToLLVMConversionPatterns(converter, patterns);
150 populateOpenMPToLLVMConversionPatterns(converter, patterns);
151
152 LLVMConversionTarget target(getContext());
153 target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
154 omp::BarrierOp, omp::TaskwaitOp>();
155 configureOpenMPToLLVMConversionLegality(target, converter);
156 if (failed(applyPartialConversion(module, target, std::move(patterns))))
157 signalPassFailure();
158 }
159
createConvertOpenMPToLLVMPass()160 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
161 return std::make_unique<ConvertOpenMPToLLVMPass>();
162 }
163