1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert scf.parallel operations into OpenMP 10 // parallel loops. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h" 17 #include "mlir/Dialect/SCF/SCF.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 22 namespace { 23 24 /// Converts SCF parallel operation into an OpenMP workshare loop construct. 25 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { 26 using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; 27 28 LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, 29 PatternRewriter &rewriter) const override { 30 // TODO: add support for reductions when OpenMP loops have them. 31 if (parallelOp.getNumResults() != 0) 32 return rewriter.notifyMatchFailure( 33 parallelOp, 34 "OpenMP dialect does not yet support loops with reductions"); 35 36 // Replace SCF yield with OpenMP yield. 37 { 38 OpBuilder::InsertionGuard guard(rewriter); 39 rewriter.setInsertionPointToEnd(parallelOp.getBody()); 40 assert(llvm::hasSingleElement(parallelOp.region()) && 41 "expected scf.parallel to have one block"); 42 rewriter.replaceOpWithNewOp<omp::YieldOp>( 43 parallelOp.getBody()->getTerminator(), ValueRange()); 44 } 45 46 // Replace the loop. 47 auto omp = rewriter.create<omp::ParallelOp>(parallelOp.getLoc()); 48 Block *block = rewriter.createBlock(&omp.getRegion()); 49 rewriter.setInsertionPointToStart(block); 50 auto loop = rewriter.create<omp::WsLoopOp>( 51 parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(), 52 parallelOp.step()); 53 rewriter.inlineRegionBefore(parallelOp.region(), loop.region(), 54 loop.region().begin()); 55 rewriter.create<omp::TerminatorOp>(parallelOp.getLoc()); 56 57 rewriter.eraseOp(parallelOp); 58 return success(); 59 } 60 }; 61 62 /// Applies the conversion patterns in the given function. 63 static LogicalResult applyPatterns(FuncOp func) { 64 ConversionTarget target(*func.getContext()); 65 target.addIllegalOp<scf::ParallelOp>(); 66 target.addDynamicallyLegalOp<scf::YieldOp>( 67 [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); }); 68 target.addLegalDialect<omp::OpenMPDialect>(); 69 70 RewritePatternSet patterns(func.getContext()); 71 patterns.add<ParallelOpLowering>(func.getContext()); 72 FrozenRewritePatternSet frozen(std::move(patterns)); 73 return applyPartialConversion(func, target, frozen); 74 } 75 76 /// A pass converting SCF operations to OpenMP operations. 77 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> { 78 /// Pass entry point. 79 void runOnFunction() override { 80 if (failed(applyPatterns(getFunction()))) 81 signalPassFailure(); 82 } 83 }; 84 85 } // end namespace 86 87 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() { 88 return std::make_unique<SCFToOpenMPPass>(); 89 } 90