1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 13 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 14 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" 15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 16 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 17 #include "mlir/Conversion/LLVMCommon/Pattern.h" 18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 21 22 using namespace mlir; 23 24 namespace { 25 /// A pattern that converts the region arguments in a single-region OpenMP 26 /// operation to the LLVM dialect. The body of the region is not modified and is 27 /// expected to either be processed by the conversion infrastructure or already 28 /// contain ops compatible with LLVM dialect types. 29 template <typename OpType> 30 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> { 31 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern; 32 33 LogicalResult 34 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, 35 ConversionPatternRewriter &rewriter) const override { 36 auto newOp = rewriter.create<OpType>( 37 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); 38 rewriter.inlineRegionBefore(curOp.region(), newOp.region(), 39 newOp.region().end()); 40 if (failed(rewriter.convertRegionTypes(&newOp.region(), 41 *this->getTypeConverter()))) 42 return failure(); 43 44 rewriter.eraseOp(curOp); 45 return success(); 46 } 47 }; 48 49 template <typename T> 50 struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> { 51 using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern; 52 LogicalResult 53 matchAndRewrite(T curOp, typename T::Adaptor adaptor, 54 ConversionPatternRewriter &rewriter) const override { 55 rewriter.replaceOpWithNewOp<T>(curOp, TypeRange(), adaptor.getOperands(), 56 curOp->getAttrs()); 57 return success(); 58 } 59 }; 60 } // namespace 61 62 void mlir::configureOpenMPToLLVMConversionLegality( 63 ConversionTarget &target, LLVMTypeConverter &typeConverter) { 64 target.addDynamicallyLegalOp<mlir::omp::ParallelOp, mlir::omp::WsLoopOp, 65 mlir::omp::MasterOp>( 66 [&](Operation *op) { return typeConverter.isLegal(&op->getRegion(0)); }); 67 target 68 .addDynamicallyLegalOp<mlir::omp::AtomicReadOp, mlir::omp::AtomicWriteOp>( 69 [&](Operation *op) { 70 return typeConverter.isLegal(op->getOperandTypes()); 71 }); 72 } 73 74 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, 75 RewritePatternSet &patterns) { 76 patterns.add<RegionOpConversion<omp::MasterOp>, 77 RegionOpConversion<omp::ParallelOp>, 78 RegionOpConversion<omp::WsLoopOp>, 79 RegionLessOpConversion<omp::AtomicReadOp>, 80 RegionLessOpConversion<omp::AtomicWriteOp>>(converter); 81 } 82 83 namespace { 84 struct ConvertOpenMPToLLVMPass 85 : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> { 86 void runOnOperation() override; 87 }; 88 } // namespace 89 90 void ConvertOpenMPToLLVMPass::runOnOperation() { 91 auto module = getOperation(); 92 93 // Convert to OpenMP operations with LLVM IR dialect 94 RewritePatternSet patterns(&getContext()); 95 LLVMTypeConverter converter(&getContext()); 96 arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); 97 cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); 98 populateMemRefToLLVMConversionPatterns(converter, patterns); 99 populateFuncToLLVMConversionPatterns(converter, patterns); 100 populateOpenMPToLLVMConversionPatterns(converter, patterns); 101 102 LLVMConversionTarget target(getContext()); 103 target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp, 104 omp::BarrierOp, omp::TaskwaitOp>(); 105 configureOpenMPToLLVMConversionLegality(target, converter); 106 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 107 signalPassFailure(); 108 } 109 110 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() { 111 return std::make_unique<ConvertOpenMPToLLVMPass>(); 112 } 113