1 //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert scf.parallel operations into OpenMP
10 // parallel loops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
17 #include "mlir/Dialect/SCF/SCF.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 
24 /// Converts SCF parallel operation into an OpenMP workshare loop construct.
25 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
26   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
27 
28   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
29                                 PatternRewriter &rewriter) const override {
30     // TODO: add support for reductions when OpenMP loops have them.
31     if (parallelOp.getNumResults() != 0)
32       return rewriter.notifyMatchFailure(
33           parallelOp,
34           "OpenMP dialect does not yet support loops with reductions");
35 
36     // Replace SCF yield with OpenMP yield.
37     {
38       OpBuilder::InsertionGuard guard(rewriter);
39       rewriter.setInsertionPointToEnd(parallelOp.getBody());
40       assert(llvm::hasSingleElement(parallelOp.region()) &&
41              "expected scf.parallel to have one block");
42       rewriter.replaceOpWithNewOp<omp::YieldOp>(
43           parallelOp.getBody()->getTerminator(), ValueRange());
44     }
45 
46     // Replace the loop.
47     auto omp = rewriter.create<omp::ParallelOp>(parallelOp.getLoc());
48     Block *block = rewriter.createBlock(&omp.getRegion());
49     rewriter.setInsertionPointToStart(block);
50     auto loop = rewriter.create<omp::WsLoopOp>(
51         parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
52         parallelOp.step());
53     rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
54                                 loop.region().begin());
55     rewriter.create<omp::TerminatorOp>(parallelOp.getLoc());
56 
57     rewriter.eraseOp(parallelOp);
58     return success();
59   }
60 };
61 
62 /// Applies the conversion patterns in the given function.
63 static LogicalResult applyPatterns(FuncOp func) {
64   ConversionTarget target(*func.getContext());
65   target.addIllegalOp<scf::ParallelOp>();
66   target.addDynamicallyLegalOp<scf::YieldOp>(
67       [](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
68   target.addLegalDialect<omp::OpenMPDialect>();
69 
70   RewritePatternSet patterns(func.getContext());
71   patterns.add<ParallelOpLowering>(func.getContext());
72   FrozenRewritePatternSet frozen(std::move(patterns));
73   return applyPartialConversion(func, target, frozen);
74 }
75 
76 /// A pass converting SCF operations to OpenMP operations.
77 struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
78   /// Pass entry point.
79   void runOnFunction() override {
80     if (failed(applyPatterns(getFunction())))
81       signalPassFailure();
82   }
83 };
84 
85 } // end namespace
86 
87 std::unique_ptr<OperationPass<FuncOp>> mlir::createConvertSCFToOpenMPPass() {
88   return std::make_unique<SCFToOpenMPPass>();
89 }
90