1875074c8SKiran Chandramohan //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
2875074c8SKiran Chandramohan //
3875074c8SKiran Chandramohan // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4875074c8SKiran Chandramohan // See https://llvm.org/LICENSE.txt for license information.
5875074c8SKiran Chandramohan // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6875074c8SKiran Chandramohan //
7875074c8SKiran Chandramohan //===----------------------------------------------------------------------===//
8875074c8SKiran Chandramohan
9875074c8SKiran Chandramohan #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
10875074c8SKiran Chandramohan
11875074c8SKiran Chandramohan #include "../PassDetail.h"
12a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
13ace01605SRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
145a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
155a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
1675e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
17684dfe8aSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
1875e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
19875074c8SKiran Chandramohan #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20875074c8SKiran Chandramohan #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
21875074c8SKiran Chandramohan
22875074c8SKiran Chandramohan using namespace mlir;
23875074c8SKiran Chandramohan
24875074c8SKiran Chandramohan namespace {
25f7d033f4SAlex Zinenko /// A pattern that converts the region arguments in a single-region OpenMP
26f7d033f4SAlex Zinenko /// operation to the LLVM dialect. The body of the region is not modified and is
27f7d033f4SAlex Zinenko /// expected to either be processed by the conversion infrastructure or already
28f7d033f4SAlex Zinenko /// contain ops compatible with LLVM dialect types.
29f7d033f4SAlex Zinenko template <typename OpType>
30563879b6SRahul Joshi struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
31563879b6SRahul Joshi using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
32875074c8SKiran Chandramohan
33875074c8SKiran Chandramohan LogicalResult
matchAndRewrite__anon494baa170111::RegionOpConversion34ef976337SRiver Riddle matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
35875074c8SKiran Chandramohan ConversionPatternRewriter &rewriter) const override {
36ef976337SRiver Riddle auto newOp = rewriter.create<OpType>(
37ef976337SRiver Riddle curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
38875074c8SKiran Chandramohan rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
39875074c8SKiran Chandramohan newOp.region().end());
40563879b6SRahul Joshi if (failed(rewriter.convertRegionTypes(&newOp.region(),
41563879b6SRahul Joshi *this->getTypeConverter())))
42875074c8SKiran Chandramohan return failure();
43875074c8SKiran Chandramohan
44563879b6SRahul Joshi rewriter.eraseOp(curOp);
45875074c8SKiran Chandramohan return success();
46875074c8SKiran Chandramohan }
47875074c8SKiran Chandramohan };
4800c511b3SNimish Mishra
4900c511b3SNimish Mishra template <typename T>
50dd32bf9aSKiran Chandramohan struct RegionLessOpWithVarOperandsConversion
51dd32bf9aSKiran Chandramohan : public ConvertOpToLLVMPattern<T> {
5200c511b3SNimish Mishra using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
5300c511b3SNimish Mishra LogicalResult
matchAndRewrite__anon494baa170111::RegionLessOpWithVarOperandsConversion5400c511b3SNimish Mishra matchAndRewrite(T curOp, typename T::Adaptor adaptor,
5500c511b3SNimish Mishra ConversionPatternRewriter &rewriter) const override {
56042ae895SPeixinQiao TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
57042ae895SPeixinQiao SmallVector<Type> resTypes;
58042ae895SPeixinQiao if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
59042ae895SPeixinQiao return failure();
60042ae895SPeixinQiao SmallVector<Value> convertedOperands;
61dd32bf9aSKiran Chandramohan assert(curOp.getNumVariableOperands() ==
62dd32bf9aSKiran Chandramohan curOp.getOperation()->getNumOperands() &&
63dd32bf9aSKiran Chandramohan "unexpected non-variable operands");
64042ae895SPeixinQiao for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
65042ae895SPeixinQiao Value originalVariableOperand = curOp.getVariableOperand(idx);
66042ae895SPeixinQiao if (!originalVariableOperand)
67042ae895SPeixinQiao return failure();
68042ae895SPeixinQiao if (originalVariableOperand.getType().isa<MemRefType>()) {
69042ae895SPeixinQiao // TODO: Support memref type in variable operands
70a5d7e2a8SAart Bik return rewriter.notifyMatchFailure(curOp,
71a5d7e2a8SAart Bik "memref is not supported yet");
72042ae895SPeixinQiao }
73118d9ebdSMehdi Amini convertedOperands.emplace_back(adaptor.getOperands()[idx]);
74042ae895SPeixinQiao }
75042ae895SPeixinQiao rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
7600c511b3SNimish Mishra curOp->getAttrs());
7700c511b3SNimish Mishra return success();
7800c511b3SNimish Mishra }
7900c511b3SNimish Mishra };
80*7bb1151bSKiran Chandramohan
81*7bb1151bSKiran Chandramohan struct ReductionOpConversion : public ConvertOpToLLVMPattern<omp::ReductionOp> {
82*7bb1151bSKiran Chandramohan using ConvertOpToLLVMPattern<omp::ReductionOp>::ConvertOpToLLVMPattern;
83*7bb1151bSKiran Chandramohan LogicalResult
matchAndRewrite__anon494baa170111::ReductionOpConversion84*7bb1151bSKiran Chandramohan matchAndRewrite(omp::ReductionOp curOp, OpAdaptor adaptor,
85*7bb1151bSKiran Chandramohan ConversionPatternRewriter &rewriter) const override {
86*7bb1151bSKiran Chandramohan if (curOp.accumulator().getType().isa<MemRefType>()) {
87*7bb1151bSKiran Chandramohan // TODO: Support memref type in variable operands
88*7bb1151bSKiran Chandramohan return rewriter.notifyMatchFailure(curOp, "memref is not supported yet");
89*7bb1151bSKiran Chandramohan }
90*7bb1151bSKiran Chandramohan rewriter.replaceOpWithNewOp<omp::ReductionOp>(
91*7bb1151bSKiran Chandramohan curOp, TypeRange(), adaptor.getOperands(), curOp->getAttrs());
92*7bb1151bSKiran Chandramohan return success();
93*7bb1151bSKiran Chandramohan }
94*7bb1151bSKiran Chandramohan };
95875074c8SKiran Chandramohan } // namespace
96875074c8SKiran Chandramohan
configureOpenMPToLLVMConversionLegality(ConversionTarget & target,LLVMTypeConverter & typeConverter)9700c511b3SNimish Mishra void mlir::configureOpenMPToLLVMConversionLegality(
9800c511b3SNimish Mishra ConversionTarget &target, LLVMTypeConverter &typeConverter) {
994ee9f3d5SKiran Chandramohan target.addDynamicallyLegalOp<mlir::omp::CriticalOp, mlir::omp::ParallelOp,
1004ee9f3d5SKiran Chandramohan mlir::omp::WsLoopOp, mlir::omp::MasterOp,
1014ee9f3d5SKiran Chandramohan mlir::omp::SectionsOp, mlir::omp::SingleOp>(
1024ee9f3d5SKiran Chandramohan [&](Operation *op) {
103dd32bf9aSKiran Chandramohan return typeConverter.isLegal(&op->getRegion(0)) &&
104dd32bf9aSKiran Chandramohan typeConverter.isLegal(op->getOperandTypes()) &&
105dd32bf9aSKiran Chandramohan typeConverter.isLegal(op->getResultTypes());
106dd32bf9aSKiran Chandramohan });
10700c511b3SNimish Mishra target
108042ae895SPeixinQiao .addDynamicallyLegalOp<mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp,
109dd32bf9aSKiran Chandramohan mlir::omp::FlushOp, mlir::omp::ThreadprivateOp>(
110dd32bf9aSKiran Chandramohan [&](Operation *op) {
111dd32bf9aSKiran Chandramohan return typeConverter.isLegal(op->getOperandTypes()) &&
112dd32bf9aSKiran Chandramohan typeConverter.isLegal(op->getResultTypes());
11300c511b3SNimish Mishra });
114*7bb1151bSKiran Chandramohan target.addDynamicallyLegalOp<mlir::omp::ReductionOp>([&](Operation *op) {
115*7bb1151bSKiran Chandramohan return typeConverter.isLegal(op->getOperandTypes());
116*7bb1151bSKiran Chandramohan });
11700c511b3SNimish Mishra }
11800c511b3SNimish Mishra
populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns)119dc4e913bSChris Lattner void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
120dc4e913bSChris Lattner RewritePatternSet &patterns) {
121dd32bf9aSKiran Chandramohan patterns.add<
122*7bb1151bSKiran Chandramohan ReductionOpConversion, RegionOpConversion<omp::CriticalOp>,
123*7bb1151bSKiran Chandramohan RegionOpConversion<omp::MasterOp>, ReductionOpConversion,
124*7bb1151bSKiran Chandramohan RegionOpConversion<omp::MasterOp>, RegionOpConversion<omp::ParallelOp>,
125*7bb1151bSKiran Chandramohan RegionOpConversion<omp::WsLoopOp>, RegionOpConversion<omp::SectionsOp>,
126*7bb1151bSKiran Chandramohan RegionOpConversion<omp::SingleOp>,
127dd32bf9aSKiran Chandramohan RegionLessOpWithVarOperandsConversion<omp::AtomicReadOp>,
128dd32bf9aSKiran Chandramohan RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
129dd32bf9aSKiran Chandramohan RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
130dd32bf9aSKiran Chandramohan RegionLessOpWithVarOperandsConversion<omp::ThreadprivateOp>>(converter);
131875074c8SKiran Chandramohan }
132875074c8SKiran Chandramohan
133875074c8SKiran Chandramohan namespace {
134875074c8SKiran Chandramohan struct ConvertOpenMPToLLVMPass
135875074c8SKiran Chandramohan : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
136875074c8SKiran Chandramohan void runOnOperation() override;
137875074c8SKiran Chandramohan };
138875074c8SKiran Chandramohan } // namespace
139875074c8SKiran Chandramohan
runOnOperation()140875074c8SKiran Chandramohan void ConvertOpenMPToLLVMPass::runOnOperation() {
141875074c8SKiran Chandramohan auto module = getOperation();
142875074c8SKiran Chandramohan
143875074c8SKiran Chandramohan // Convert to OpenMP operations with LLVM IR dialect
144dc4e913bSChris Lattner RewritePatternSet patterns(&getContext());
145875074c8SKiran Chandramohan LLVMTypeConverter converter(&getContext());
146ace01605SRiver Riddle arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
147ace01605SRiver Riddle cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
14875e5f0aaSAlex Zinenko populateMemRefToLLVMConversionPatterns(converter, patterns);
1495a7b9194SRiver Riddle populateFuncToLLVMConversionPatterns(converter, patterns);
150563879b6SRahul Joshi populateOpenMPToLLVMConversionPatterns(converter, patterns);
151875074c8SKiran Chandramohan
152875074c8SKiran Chandramohan LLVMConversionTarget target(getContext());
153875074c8SKiran Chandramohan target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
154875074c8SKiran Chandramohan omp::BarrierOp, omp::TaskwaitOp>();
15500c511b3SNimish Mishra configureOpenMPToLLVMConversionLegality(target, converter);
1563fffffa8SRiver Riddle if (failed(applyPartialConversion(module, target, std::move(patterns))))
157875074c8SKiran Chandramohan signalPassFailure();
158875074c8SKiran Chandramohan }
159875074c8SKiran Chandramohan
createConvertOpenMPToLLVMPass()160875074c8SKiran Chandramohan std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
161875074c8SKiran Chandramohan return std::make_unique<ConvertOpenMPToLLVMPass>();
162875074c8SKiran Chandramohan }
163