1 //===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "../PassDetail.h" 10 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/OpenACC/OpenACC.h" 13 #include "mlir/Dialect/SCF/SCF.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/Transforms/DialectConversion.h" 16 17 using namespace mlir; 18 19 //===----------------------------------------------------------------------===// 20 // Conversion patterns 21 //===----------------------------------------------------------------------===// 22 23 namespace { 24 /// Pattern to transform the `ifCond` on operation without region into a scf.if 25 /// and move the operation into the `then` region. 26 template <typename OpTy> 27 class ExpandIfCondition : public OpRewritePattern<OpTy> { 28 using OpRewritePattern<OpTy>::OpRewritePattern; 29 30 LogicalResult matchAndRewrite(OpTy op, 31 PatternRewriter &rewriter) const override { 32 // Early exit if there is no condition. 33 if (!op.ifCond()) 34 return success(); 35 36 // Condition is not a constant. 37 if (!op.ifCond().template getDefiningOp<arith::ConstantOp>()) { 38 auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(), 39 op.ifCond(), false); 40 rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); 41 auto thenBodyBuilder = ifOp.getThenBodyBuilder(); 42 thenBodyBuilder.setListener(rewriter.getListener()); 43 thenBodyBuilder.clone(*op.getOperation()); 44 rewriter.eraseOp(op); 45 } 46 47 return success(); 48 } 49 }; 50 } // namespace 51 52 void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) { 53 patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext()); 54 patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext()); 55 patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext()); 56 } 57 58 namespace { 59 struct ConvertOpenACCToSCFPass 60 : public ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> { 61 void runOnOperation() override; 62 }; 63 } // namespace 64 65 void ConvertOpenACCToSCFPass::runOnOperation() { 66 auto op = getOperation(); 67 auto *context = op.getContext(); 68 69 RewritePatternSet patterns(context); 70 ConversionTarget target(*context); 71 populateOpenACCToSCFConversionPatterns(patterns); 72 73 target.addLegalDialect<scf::SCFDialect>(); 74 target.addLegalDialect<acc::OpenACCDialect>(); 75 76 target.addDynamicallyLegalOp<acc::EnterDataOp>( 77 [](acc::EnterDataOp op) { return !op.ifCond(); }); 78 79 target.addDynamicallyLegalOp<acc::ExitDataOp>( 80 [](acc::ExitDataOp op) { return !op.ifCond(); }); 81 82 target.addDynamicallyLegalOp<acc::UpdateOp>( 83 [](acc::UpdateOp op) { return !op.ifCond(); }); 84 85 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 86 signalPassFailure(); 87 } 88 89 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() { 90 return std::make_unique<ConvertOpenACCToSCFPass>(); 91 } 92