1 //===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "../PassDetail.h" 10 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h" 11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 12 #include "mlir/Dialect/OpenACC/OpenACC.h" 13 #include "mlir/Dialect/SCF/IR/SCF.h" 14 #include "mlir/Transforms/DialectConversion.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // Conversion patterns 20 //===----------------------------------------------------------------------===// 21 22 namespace { 23 /// Pattern to transform the `ifCond` on operation without region into a scf.if 24 /// and move the operation into the `then` region. 25 template <typename OpTy> 26 class ExpandIfCondition : public OpRewritePattern<OpTy> { 27 using OpRewritePattern<OpTy>::OpRewritePattern; 28 matchAndRewrite(OpTy op,PatternRewriter & rewriter) const29 LogicalResult matchAndRewrite(OpTy op, 30 PatternRewriter &rewriter) const override { 31 // Early exit if there is no condition. 32 if (!op.ifCond()) 33 return success(); 34 35 // Condition is not a constant. 36 if (!op.ifCond().template getDefiningOp<arith::ConstantOp>()) { 37 auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(), 38 op.ifCond(), false); 39 rewriter.updateRootInPlace(op, [&]() { op.ifCondMutable().erase(0); }); 40 auto thenBodyBuilder = ifOp.getThenBodyBuilder(); 41 thenBodyBuilder.setListener(rewriter.getListener()); 42 thenBodyBuilder.clone(*op.getOperation()); 43 rewriter.eraseOp(op); 44 } 45 46 return success(); 47 } 48 }; 49 } // namespace 50 populateOpenACCToSCFConversionPatterns(RewritePatternSet & patterns)51void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) { 52 patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext()); 53 patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext()); 54 patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext()); 55 } 56 57 namespace { 58 struct ConvertOpenACCToSCFPass 59 : public ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> { 60 void runOnOperation() override; 61 }; 62 } // namespace 63 runOnOperation()64void ConvertOpenACCToSCFPass::runOnOperation() { 65 auto op = getOperation(); 66 auto *context = op.getContext(); 67 68 RewritePatternSet patterns(context); 69 ConversionTarget target(*context); 70 populateOpenACCToSCFConversionPatterns(patterns); 71 72 target.addLegalDialect<scf::SCFDialect>(); 73 target.addLegalDialect<acc::OpenACCDialect>(); 74 75 target.addDynamicallyLegalOp<acc::EnterDataOp>( 76 [](acc::EnterDataOp op) { return !op.ifCond(); }); 77 78 target.addDynamicallyLegalOp<acc::ExitDataOp>( 79 [](acc::ExitDataOp op) { return !op.ifCond(); }); 80 81 target.addDynamicallyLegalOp<acc::UpdateOp>( 82 [](acc::UpdateOp op) { return !op.ifCond(); }); 83 84 if (failed(applyPartialConversion(op, target, std::move(patterns)))) 85 signalPassFailure(); 86 } 87 createConvertOpenACCToSCFPass()88std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() { 89 return std::make_unique<ConvertOpenACCToSCFPass>(); 90 } 91