1 //===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" 10 11 #include "../PassDetail.h" 12 #include "mlir/Conversion/LLVMCommon/Pattern.h" 13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 17 18 using namespace mlir; 19 20 namespace { 21 /// A pattern that converts the region arguments in a single-region OpenMP 22 /// operation to the LLVM dialect. The body of the region is not modified and is 23 /// expected to either be processed by the conversion infrastructure or already 24 /// contain ops compatible with LLVM dialect types. 25 template <typename OpType> 26 struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> { 27 using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern; 28 29 LogicalResult 30 matchAndRewrite(OpType curOp, ArrayRef<Value> operands, 31 ConversionPatternRewriter &rewriter) const override { 32 auto newOp = rewriter.create<OpType>(curOp.getLoc(), TypeRange(), operands, 33 curOp->getAttrs()); 34 rewriter.inlineRegionBefore(curOp.region(), newOp.region(), 35 newOp.region().end()); 36 if (failed(rewriter.convertRegionTypes(&newOp.region(), 37 *this->getTypeConverter()))) 38 return failure(); 39 40 rewriter.eraseOp(curOp); 41 return success(); 42 } 43 }; 44 } // namespace 45 46 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, 47 RewritePatternSet &patterns) { 48 patterns.add<RegionOpConversion<omp::ParallelOp>, 49 RegionOpConversion<omp::WsLoopOp>>(converter); 50 } 51 52 namespace { 53 struct ConvertOpenMPToLLVMPass 54 : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> { 55 void runOnOperation() override; 56 }; 57 } // namespace 58 59 void ConvertOpenMPToLLVMPass::runOnOperation() { 60 auto module = getOperation(); 61 62 // Convert to OpenMP operations with LLVM IR dialect 63 RewritePatternSet patterns(&getContext()); 64 LLVMTypeConverter converter(&getContext()); 65 populateStdToLLVMConversionPatterns(converter, patterns); 66 populateOpenMPToLLVMConversionPatterns(converter, patterns); 67 68 LLVMConversionTarget target(getContext()); 69 target.addDynamicallyLegalOp<omp::ParallelOp, omp::WsLoopOp>( 70 [&](Operation *op) { return converter.isLegal(&op->getRegion(0)); }); 71 target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp, 72 omp::BarrierOp, omp::TaskwaitOp>(); 73 if (failed(applyPartialConversion(module, target, std::move(patterns)))) 74 signalPassFailure(); 75 } 76 77 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() { 78 return std::make_unique<ConvertOpenMPToLLVMPass>(); 79 } 80