1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 13 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 14 #include "mlir/Conversion/LLVMCommon/Pattern.h" 15 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 16 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 17 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 18 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 19 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// A pattern that converts the region arguments in a single-region OpenMP 25 /// operation to the LLVM dialect. The body of the region is not modified and is 26 /// expected to either be processed by the conversion infrastructure or already 27 /// contain ops compatible with LLVM dialect types. 28 template <typename OpType> 29 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> { 30 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern; 31 32 LogicalResult 33 matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor, 34 ConversionPatternRewriter &rewriter) const override { 35 auto newOp = rewriter.create<OpType>( 36 curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs()); 37 rewriter.inlineRegionBefore(curOp.region(), newOp.region(), 38 newOp.region().end()); 39 if (failed(rewriter.convertRegionTypes(&newOp.region(), 40 *this->getTypeConverter()))) 41 return failure(); 42 43 rewriter.eraseOp(curOp); 44 return success(); 45 } 46 }; 47 } // namespace 48 49 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, 50 RewritePatternSet &patterns) { 51 patterns.add<RegionOpConversion<omp::MasterOp>, 52 RegionOpConversion<omp::ParallelOp>, 53 RegionOpConversion<omp::WsLoopOp>>(converter); 54 } 55 56 namespace { 57 struct ConvertOpenMPToLLVMPass 58 : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> { 59 void runOnOperation() override; 60 }; 61 } // namespace 62 63 void ConvertOpenMPToLLVMPass::runOnOperation() { 64 auto module = getOperation(); 65 66 // Convert to OpenMP operations with LLVM IR dialect 67 RewritePatternSet patterns(&getContext()); 68 LLVMTypeConverter converter(&getContext()); 69 mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns); 70 populateMemRefToLLVMConversionPatterns(converter, patterns); 71 populateStdToLLVMConversionPatterns(converter, patterns); 72 populateOpenMPToLLVMConversionPatterns(converter, patterns); 73 74 LLVMConversionTarget target(getContext()); 75 target.addDynamicallyLegalOp<omp::MasterOp, omp::ParallelOp, omp::WsLoopOp>( 76 [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); }); 77 target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp, 78 omp::BarrierOp, omp::TaskwaitOp>(); 79 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 80 signalPassFailure(); 81 } 82 83 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() { 84 return std::make_unique<ConvertOpenMPToLLVMPass>(); 85 } 86