1 //===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "../PassDetail.h"
#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/OpenACC/OpenACC.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"
16 
17 using namespace mlir;
18 
19 //===----------------------------------------------------------------------===//
20 // Conversion patterns
21 //===----------------------------------------------------------------------===//
22 
23 namespace {
24 /// Pattern to transform the `ifCond` on operation without region into a scf.if
25 /// and move the operation into the `then` region.
26 template <typename OpTy>
27 class ExpandIfCondition : public OpRewritePattern<OpTy> {
28   using OpRewritePattern<OpTy>::OpRewritePattern;
29 
30   LogicalResult matchAndRewrite(OpTy op,
31                                 PatternRewriter &rewriter) const override {
32     // Early exit if there is no condition.
33     if (!op.ifCond())
34       return success();
35 
36     // Condition is not a constant.
37     if (!op.ifCond().template getDefiningOp<arith::ConstantOp>()) {
38       auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
39                                              op.ifCond(), false);
40       rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); });
41       auto thenBodyBuilder = ifOp.getThenBodyBuilder();
42       thenBodyBuilder.setListener(rewriter.getListener());
43       thenBodyBuilder.clone(*op.getOperation());
44       rewriter.eraseOp(op);
45     }
46 
47     return success();
48   }
49 };
50 } // namespace
51 
52 void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
53   patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
54   patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
55   patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
56 }
57 
58 namespace {
59 struct ConvertOpenACCToSCFPass
60     : public ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> {
61   void runOnOperation() override;
62 };
63 } // namespace
64 
65 void ConvertOpenACCToSCFPass::runOnOperation() {
66   auto op = getOperation();
67   auto *context = op.getContext();
68 
69   RewritePatternSet patterns(context);
70   ConversionTarget target(*context);
71   populateOpenACCToSCFConversionPatterns(patterns);
72 
73   target.addLegalDialect<scf::SCFDialect>();
74   target.addLegalDialect<acc::OpenACCDialect>();
75 
76   target.addDynamicallyLegalOp<acc::EnterDataOp>(
77       [](acc::EnterDataOp op) { return !op.ifCond(); });
78 
79   target.addDynamicallyLegalOp<acc::ExitDataOp>(
80       [](acc::ExitDataOp op) { return !op.ifCond(); });
81 
82   target.addDynamicallyLegalOp<acc::UpdateOp>(
83       [](acc::UpdateOp op) { return !op.ifCond(); });
84 
85   if (failed(applyPartialConversion(op, target, std::move(patterns))))
86     signalPassFailure();
87 }
88 
89 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() {
90   return std::make_unique<ConvertOpenACCToSCFPass>();
91 }
92