197d8972cSEric Schweitz //===-- PreCGRewrite.cpp --------------------------------------------------===//
297d8972cSEric Schweitz //
397d8972cSEric Schweitz // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
497d8972cSEric Schweitz // See https://llvm.org/LICENSE.txt for license information.
597d8972cSEric Schweitz // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
697d8972cSEric Schweitz //
797d8972cSEric Schweitz //===----------------------------------------------------------------------===//
897d8972cSEric Schweitz //
997d8972cSEric Schweitz // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
1097d8972cSEric Schweitz //
1197d8972cSEric Schweitz //===----------------------------------------------------------------------===//
1297d8972cSEric Schweitz
1397d8972cSEric Schweitz #include "CGOps.h"
1497d8972cSEric Schweitz #include "PassDetail.h"
1597d8972cSEric Schweitz #include "flang/Optimizer/CodeGen/CodeGen.h"
1697d8972cSEric Schweitz #include "flang/Optimizer/Dialect/FIRDialect.h"
1797d8972cSEric Schweitz #include "flang/Optimizer/Dialect/FIROps.h"
1897d8972cSEric Schweitz #include "flang/Optimizer/Dialect/FIRType.h"
1997d8972cSEric Schweitz #include "flang/Optimizer/Support/FIRContext.h"
2097d8972cSEric Schweitz #include "mlir/Transforms/DialectConversion.h"
2197d8972cSEric Schweitz #include "llvm/ADT/STLExtras.h"
22861eff24SNico Weber #include "llvm/Support/Debug.h"
2397d8972cSEric Schweitz
2497d8972cSEric Schweitz //===----------------------------------------------------------------------===//
2597d8972cSEric Schweitz // Codegen rewrite: rewriting of subgraphs of ops
2697d8972cSEric Schweitz //===----------------------------------------------------------------------===//
2797d8972cSEric Schweitz
2897d8972cSEric Schweitz #define DEBUG_TYPE "flang-codegen-rewrite"
2997d8972cSEric Schweitz
populateShape(llvm::SmallVectorImpl<mlir::Value> & vec,fir::ShapeOp shape)3097d8972cSEric Schweitz static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
311f31795cSEric Schweitz fir::ShapeOp shape) {
32149ad3d5SShraiysh Vaishay vec.append(shape.getExtents().begin(), shape.getExtents().end());
3397d8972cSEric Schweitz }
3497d8972cSEric Schweitz
3597d8972cSEric Schweitz // Operands of fir.shape_shift split into two vectors.
populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> & shapeVec,llvm::SmallVectorImpl<mlir::Value> & shiftVec,fir::ShapeShiftOp shift)3697d8972cSEric Schweitz static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
3797d8972cSEric Schweitz llvm::SmallVectorImpl<mlir::Value> &shiftVec,
381f31795cSEric Schweitz fir::ShapeShiftOp shift) {
391f31795cSEric Schweitz for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end();
401f31795cSEric Schweitz i != endIter;) {
4197d8972cSEric Schweitz shiftVec.push_back(*i++);
4297d8972cSEric Schweitz shapeVec.push_back(*i++);
4397d8972cSEric Schweitz }
4497d8972cSEric Schweitz }
4597d8972cSEric Schweitz
populateShift(llvm::SmallVectorImpl<mlir::Value> & vec,fir::ShiftOp shift)4697d8972cSEric Schweitz static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
471f31795cSEric Schweitz fir::ShiftOp shift) {
48149ad3d5SShraiysh Vaishay vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
4997d8972cSEric Schweitz }
5097d8972cSEric Schweitz
5197d8972cSEric Schweitz namespace {
5297d8972cSEric Schweitz
5397d8972cSEric Schweitz /// Convert fir.embox to the extended form where necessary.
5497d8972cSEric Schweitz ///
5597d8972cSEric Schweitz /// The embox operation can take arguments that specify multidimensional array
5697d8972cSEric Schweitz /// properties at runtime. These properties may be shared between distinct
5797d8972cSEric Schweitz /// objects that have the same properties. Before we lower these small DAGs to
5897d8972cSEric Schweitz /// LLVM-IR, we gather all the information into a single extended operation. For
5997d8972cSEric Schweitz /// example,
6097d8972cSEric Schweitz /// ```
6197d8972cSEric Schweitz /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
6297d8972cSEric Schweitz /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
63a54f4eaeSMogball /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
64a54f4eaeSMogball /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
6597d8972cSEric Schweitz /// ```
6697d8972cSEric Schweitz /// can be rewritten as
6797d8972cSEric Schweitz /// ```
68a54f4eaeSMogball /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
69a54f4eaeSMogball /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
70a54f4eaeSMogball /// !fir.box<!fir.array<?xi32>>
7197d8972cSEric Schweitz /// ```
721f31795cSEric Schweitz class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> {
7397d8972cSEric Schweitz public:
7497d8972cSEric Schweitz using OpRewritePattern::OpRewritePattern;
7597d8972cSEric Schweitz
7697d8972cSEric Schweitz mlir::LogicalResult
matchAndRewrite(fir::EmboxOp embox,mlir::PatternRewriter & rewriter) const771f31795cSEric Schweitz matchAndRewrite(fir::EmboxOp embox,
7897d8972cSEric Schweitz mlir::PatternRewriter &rewriter) const override {
7997d8972cSEric Schweitz // If the embox does not include a shape, then do not convert it
801f31795cSEric Schweitz if (auto shapeVal = embox.getShape())
8197d8972cSEric Schweitz return rewriteDynamicShape(embox, rewriter, shapeVal);
821f31795cSEric Schweitz if (auto boxTy = embox.getType().dyn_cast<fir::BoxType>())
831f31795cSEric Schweitz if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
8497d8972cSEric Schweitz if (seqTy.hasConstantShape())
8597d8972cSEric Schweitz return rewriteStaticShape(embox, rewriter, seqTy);
8697d8972cSEric Schweitz return mlir::failure();
8797d8972cSEric Schweitz }
8897d8972cSEric Schweitz
rewriteStaticShape(fir::EmboxOp embox,mlir::PatternRewriter & rewriter,fir::SequenceType seqTy) const891f31795cSEric Schweitz mlir::LogicalResult rewriteStaticShape(fir::EmboxOp embox,
9097d8972cSEric Schweitz mlir::PatternRewriter &rewriter,
911f31795cSEric Schweitz fir::SequenceType seqTy) const {
9297d8972cSEric Schweitz auto loc = embox.getLoc();
9397d8972cSEric Schweitz llvm::SmallVector<mlir::Value> shapeOpers;
9497d8972cSEric Schweitz auto idxTy = rewriter.getIndexType();
9597d8972cSEric Schweitz for (auto ext : seqTy.getShape()) {
9697d8972cSEric Schweitz auto iAttr = rewriter.getIndexAttr(ext);
97a54f4eaeSMogball auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
9897d8972cSEric Schweitz shapeOpers.push_back(extVal);
9997d8972cSEric Schweitz }
1001f31795cSEric Schweitz auto xbox = rewriter.create<fir::cg::XEmboxOp>(
101149ad3d5SShraiysh Vaishay loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None,
102149ad3d5SShraiysh Vaishay llvm::None, llvm::None, llvm::None, embox.getTypeparams());
10397d8972cSEric Schweitz LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
10497d8972cSEric Schweitz rewriter.replaceOp(embox, xbox.getOperation()->getResults());
10597d8972cSEric Schweitz return mlir::success();
10697d8972cSEric Schweitz }
10797d8972cSEric Schweitz
rewriteDynamicShape(fir::EmboxOp embox,mlir::PatternRewriter & rewriter,mlir::Value shapeVal) const1081f31795cSEric Schweitz mlir::LogicalResult rewriteDynamicShape(fir::EmboxOp embox,
10997d8972cSEric Schweitz mlir::PatternRewriter &rewriter,
11097d8972cSEric Schweitz mlir::Value shapeVal) const {
11197d8972cSEric Schweitz auto loc = embox.getLoc();
11297d8972cSEric Schweitz llvm::SmallVector<mlir::Value> shapeOpers;
11397d8972cSEric Schweitz llvm::SmallVector<mlir::Value> shiftOpers;
1141f31795cSEric Schweitz if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) {
11597d8972cSEric Schweitz populateShape(shapeOpers, shapeOp);
11697d8972cSEric Schweitz } else {
1171f31795cSEric Schweitz auto shiftOp =
1181f31795cSEric Schweitz mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp());
11997d8972cSEric Schweitz assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
12097d8972cSEric Schweitz populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
12197d8972cSEric Schweitz }
12297d8972cSEric Schweitz llvm::SmallVector<mlir::Value> sliceOpers;
12397d8972cSEric Schweitz llvm::SmallVector<mlir::Value> subcompOpers;
1243c7ff45cSValentin Clement llvm::SmallVector<mlir::Value> substrOpers;
12597d8972cSEric Schweitz if (auto s = embox.getSlice())
1261f31795cSEric Schweitz if (auto sliceOp =
1271f31795cSEric Schweitz mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
128149ad3d5SShraiysh Vaishay sliceOpers.assign(sliceOp.getTriples().begin(),
129149ad3d5SShraiysh Vaishay sliceOp.getTriples().end());
130149ad3d5SShraiysh Vaishay subcompOpers.assign(sliceOp.getFields().begin(),
131149ad3d5SShraiysh Vaishay sliceOp.getFields().end());
132149ad3d5SShraiysh Vaishay substrOpers.assign(sliceOp.getSubstr().begin(),
133149ad3d5SShraiysh Vaishay sliceOp.getSubstr().end());
13497d8972cSEric Schweitz }
1351f31795cSEric Schweitz auto xbox = rewriter.create<fir::cg::XEmboxOp>(
136149ad3d5SShraiysh Vaishay loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers,
137149ad3d5SShraiysh Vaishay sliceOpers, subcompOpers, substrOpers, embox.getTypeparams());
13897d8972cSEric Schweitz LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
13997d8972cSEric Schweitz rewriter.replaceOp(embox, xbox.getOperation()->getResults());
14097d8972cSEric Schweitz return mlir::success();
14197d8972cSEric Schweitz }
14297d8972cSEric Schweitz };
14397d8972cSEric Schweitz
14497d8972cSEric Schweitz /// Convert fir.rebox to the extended form where necessary.
14597d8972cSEric Schweitz ///
14697d8972cSEric Schweitz /// For example,
14797d8972cSEric Schweitz /// ```
148a54f4eaeSMogball /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
149a54f4eaeSMogball /// !fir.box<!fir.array<?xi32>>
15097d8972cSEric Schweitz /// ```
15197d8972cSEric Schweitz /// converted to
15297d8972cSEric Schweitz /// ```
153a54f4eaeSMogball /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
154a54f4eaeSMogball /// index, index) -> !fir.box<!fir.array<?xi32>>
15597d8972cSEric Schweitz /// ```
1561f31795cSEric Schweitz class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> {
15797d8972cSEric Schweitz public:
15897d8972cSEric Schweitz using OpRewritePattern::OpRewritePattern;
15997d8972cSEric Schweitz
16097d8972cSEric Schweitz mlir::LogicalResult
matchAndRewrite(fir::ReboxOp rebox,mlir::PatternRewriter & rewriter) const1611f31795cSEric Schweitz matchAndRewrite(fir::ReboxOp rebox,
16297d8972cSEric Schweitz mlir::PatternRewriter &rewriter) const override {
16397d8972cSEric Schweitz auto loc = rebox.getLoc();
16497d8972cSEric Schweitz llvm::SmallVector<mlir::Value> shapeOpers;
16597d8972cSEric Schweitz llvm::SmallVector<mlir::Value> shiftOpers;
166149ad3d5SShraiysh Vaishay if (auto shapeVal = rebox.getShape()) {
1671f31795cSEric Schweitz if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
16897d8972cSEric Schweitz populateShape(shapeOpers, shapeOp);
1691f31795cSEric Schweitz else if (auto shiftOp =
1701f31795cSEric Schweitz mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
17197d8972cSEric Schweitz populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
1721f31795cSEric Schweitz else if (auto shiftOp =
1731f31795cSEric Schweitz mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
17497d8972cSEric Schweitz populateShift(shiftOpers, shiftOp);
17597d8972cSEric Schweitz else
17697d8972cSEric Schweitz return mlir::failure();
17797d8972cSEric Schweitz }
17897d8972cSEric Schweitz llvm::SmallVector<mlir::Value> sliceOpers;
17997d8972cSEric Schweitz llvm::SmallVector<mlir::Value> subcompOpers;
1803c7ff45cSValentin Clement llvm::SmallVector<mlir::Value> substrOpers;
181149ad3d5SShraiysh Vaishay if (auto s = rebox.getSlice())
1821f31795cSEric Schweitz if (auto sliceOp =
1831f31795cSEric Schweitz mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
184149ad3d5SShraiysh Vaishay sliceOpers.append(sliceOp.getTriples().begin(),
185149ad3d5SShraiysh Vaishay sliceOp.getTriples().end());
186149ad3d5SShraiysh Vaishay subcompOpers.append(sliceOp.getFields().begin(),
187149ad3d5SShraiysh Vaishay sliceOp.getFields().end());
188149ad3d5SShraiysh Vaishay substrOpers.append(sliceOp.getSubstr().begin(),
189149ad3d5SShraiysh Vaishay sliceOp.getSubstr().end());
19097d8972cSEric Schweitz }
19197d8972cSEric Schweitz
1921f31795cSEric Schweitz auto xRebox = rewriter.create<fir::cg::XReboxOp>(
193149ad3d5SShraiysh Vaishay loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers,
194149ad3d5SShraiysh Vaishay sliceOpers, subcompOpers, substrOpers);
19597d8972cSEric Schweitz LLVM_DEBUG(llvm::dbgs()
19697d8972cSEric Schweitz << "rewriting " << rebox << " to " << xRebox << '\n');
19797d8972cSEric Schweitz rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
19897d8972cSEric Schweitz return mlir::success();
19997d8972cSEric Schweitz }
20097d8972cSEric Schweitz };
20197d8972cSEric Schweitz
20297d8972cSEric Schweitz /// Convert all fir.array_coor to the extended form.
20397d8972cSEric Schweitz ///
20497d8972cSEric Schweitz /// For example,
20597d8972cSEric Schweitz /// ```
206a54f4eaeSMogball /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
207a54f4eaeSMogball /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
20897d8972cSEric Schweitz /// ```
20997d8972cSEric Schweitz /// converted to
21097d8972cSEric Schweitz /// ```
211a54f4eaeSMogball /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
212a54f4eaeSMogball /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
213a54f4eaeSMogball /// !fir.ref<i32>
21497d8972cSEric Schweitz /// ```
2151f31795cSEric Schweitz class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
21697d8972cSEric Schweitz public:
21797d8972cSEric Schweitz using OpRewritePattern::OpRewritePattern;
21897d8972cSEric Schweitz
21997d8972cSEric Schweitz mlir::LogicalResult
matchAndRewrite(fir::ArrayCoorOp arrCoor,mlir::PatternRewriter & rewriter) const2201f31795cSEric Schweitz matchAndRewrite(fir::ArrayCoorOp arrCoor,
22197d8972cSEric Schweitz mlir::PatternRewriter &rewriter) const override {
22297d8972cSEric Schweitz auto loc = arrCoor.getLoc();
22397d8972cSEric Schweitz llvm::SmallVector<mlir::Value> shapeOpers;
22497d8972cSEric Schweitz llvm::SmallVector<mlir::Value> shiftOpers;
225149ad3d5SShraiysh Vaishay if (auto shapeVal = arrCoor.getShape()) {
2261f31795cSEric Schweitz if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
22797d8972cSEric Schweitz populateShape(shapeOpers, shapeOp);
2281f31795cSEric Schweitz else if (auto shiftOp =
2291f31795cSEric Schweitz mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
23097d8972cSEric Schweitz populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
2311f31795cSEric Schweitz else if (auto shiftOp =
2321f31795cSEric Schweitz mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
23397d8972cSEric Schweitz populateShift(shiftOpers, shiftOp);
23497d8972cSEric Schweitz else
23597d8972cSEric Schweitz return mlir::failure();
23697d8972cSEric Schweitz }
23797d8972cSEric Schweitz llvm::SmallVector<mlir::Value> sliceOpers;
23897d8972cSEric Schweitz llvm::SmallVector<mlir::Value> subcompOpers;
239149ad3d5SShraiysh Vaishay if (auto s = arrCoor.getSlice())
2401f31795cSEric Schweitz if (auto sliceOp =
2411f31795cSEric Schweitz mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
242149ad3d5SShraiysh Vaishay sliceOpers.append(sliceOp.getTriples().begin(),
243149ad3d5SShraiysh Vaishay sliceOp.getTriples().end());
244149ad3d5SShraiysh Vaishay subcompOpers.append(sliceOp.getFields().begin(),
245149ad3d5SShraiysh Vaishay sliceOp.getFields().end());
246149ad3d5SShraiysh Vaishay assert(sliceOp.getSubstr().empty() &&
2473c7ff45cSValentin Clement "Don't allow substring operations on array_coor. This "
2483c7ff45cSValentin Clement "restriction may be lifted in the future.");
24997d8972cSEric Schweitz }
2501f31795cSEric Schweitz auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>(
251149ad3d5SShraiysh Vaishay loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers,
252149ad3d5SShraiysh Vaishay sliceOpers, subcompOpers, arrCoor.getIndices(),
253149ad3d5SShraiysh Vaishay arrCoor.getTypeparams());
25497d8972cSEric Schweitz LLVM_DEBUG(llvm::dbgs()
25597d8972cSEric Schweitz << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
25697d8972cSEric Schweitz rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
25797d8972cSEric Schweitz return mlir::success();
25897d8972cSEric Schweitz }
25997d8972cSEric Schweitz };
26097d8972cSEric Schweitz
2611f31795cSEric Schweitz class CodeGenRewrite : public fir::CodeGenRewriteBase<CodeGenRewrite> {
26297d8972cSEric Schweitz public:
runOn(mlir::Operation * op,mlir::Region & region)263*b09426ffSValentin Clement void runOn(mlir::Operation *op, mlir::Region ®ion) {
26497d8972cSEric Schweitz auto &context = getContext();
26597d8972cSEric Schweitz mlir::OpBuilder rewriter(&context);
26697d8972cSEric Schweitz mlir::ConversionTarget target(context);
2671f31795cSEric Schweitz target.addLegalDialect<mlir::arith::ArithmeticDialect, fir::FIROpsDialect,
2681f31795cSEric Schweitz fir::FIRCodeGenDialect, mlir::func::FuncDialect>();
2691f31795cSEric Schweitz target.addIllegalOp<fir::ArrayCoorOp>();
2701f31795cSEric Schweitz target.addIllegalOp<fir::ReboxOp>();
2711f31795cSEric Schweitz target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) {
2721f31795cSEric Schweitz return !(embox.getShape() || embox.getType()
2731f31795cSEric Schweitz .cast<fir::BoxType>()
2741f31795cSEric Schweitz .getEleTy()
2751f31795cSEric Schweitz .isa<fir::SequenceType>());
27697d8972cSEric Schweitz });
2779f85c198SRiver Riddle mlir::RewritePatternSet patterns(&context);
27897d8972cSEric Schweitz patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
27997d8972cSEric Schweitz &context);
28097d8972cSEric Schweitz if (mlir::failed(
28197d8972cSEric Schweitz mlir::applyPartialConversion(op, target, std::move(patterns)))) {
28297d8972cSEric Schweitz mlir::emitError(mlir::UnknownLoc::get(&context),
28397d8972cSEric Schweitz "error in running the pre-codegen conversions");
28497d8972cSEric Schweitz signalPassFailure();
28597d8972cSEric Schweitz }
286*b09426ffSValentin Clement // Erase any residual.
287*b09426ffSValentin Clement simplifyRegion(region);
28897d8972cSEric Schweitz }
289*b09426ffSValentin Clement
runOnOperation()290*b09426ffSValentin Clement void runOnOperation() override final {
291*b09426ffSValentin Clement // Call runOn on all top level regions that may contain emboxOp/arrayCoorOp.
292*b09426ffSValentin Clement auto mod = getOperation();
293*b09426ffSValentin Clement for (auto func : mod.getOps<mlir::func::FuncOp>())
294*b09426ffSValentin Clement runOn(func, func.getBody());
295*b09426ffSValentin Clement for (auto global : mod.getOps<fir::GlobalOp>())
296*b09426ffSValentin Clement runOn(global, global.getRegion());
297*b09426ffSValentin Clement }
298*b09426ffSValentin Clement
299*b09426ffSValentin Clement // Clean up the region.
simplifyRegion(mlir::Region & region)300*b09426ffSValentin Clement void simplifyRegion(mlir::Region ®ion) {
301*b09426ffSValentin Clement for (auto &block : region.getBlocks())
302*b09426ffSValentin Clement for (auto &op : block.getOperations()) {
303*b09426ffSValentin Clement for (auto ® : op.getRegions())
304*b09426ffSValentin Clement simplifyRegion(reg);
305*b09426ffSValentin Clement maybeEraseOp(&op);
306*b09426ffSValentin Clement }
307*b09426ffSValentin Clement doDCE();
308*b09426ffSValentin Clement }
309*b09426ffSValentin Clement
310*b09426ffSValentin Clement /// Run a simple DCE cleanup to remove any dead code after the rewrites.
doDCE()311*b09426ffSValentin Clement void doDCE() {
312*b09426ffSValentin Clement std::vector<mlir::Operation *> workList;
313*b09426ffSValentin Clement workList.swap(opsToErase);
314*b09426ffSValentin Clement while (!workList.empty()) {
315*b09426ffSValentin Clement for (auto *op : workList) {
316*b09426ffSValentin Clement std::vector<mlir::Value> opOperands(op->operand_begin(),
317*b09426ffSValentin Clement op->operand_end());
318*b09426ffSValentin Clement LLVM_DEBUG(llvm::dbgs() << "DCE on " << *op << '\n');
319*b09426ffSValentin Clement ++numDCE;
320*b09426ffSValentin Clement op->erase();
321*b09426ffSValentin Clement for (auto opnd : opOperands)
322*b09426ffSValentin Clement maybeEraseOp(opnd.getDefiningOp());
323*b09426ffSValentin Clement }
324*b09426ffSValentin Clement workList.clear();
325*b09426ffSValentin Clement workList.swap(opsToErase);
326*b09426ffSValentin Clement }
327*b09426ffSValentin Clement }
328*b09426ffSValentin Clement
maybeEraseOp(mlir::Operation * op)329*b09426ffSValentin Clement void maybeEraseOp(mlir::Operation *op) {
330*b09426ffSValentin Clement if (!op)
331*b09426ffSValentin Clement return;
332*b09426ffSValentin Clement if (op->hasTrait<mlir::OpTrait::IsTerminator>())
333*b09426ffSValentin Clement return;
334*b09426ffSValentin Clement if (mlir::isOpTriviallyDead(op))
335*b09426ffSValentin Clement opsToErase.push_back(op);
336*b09426ffSValentin Clement }
337*b09426ffSValentin Clement
338*b09426ffSValentin Clement private:
339*b09426ffSValentin Clement std::vector<mlir::Operation *> opsToErase;
34097d8972cSEric Schweitz };
34197d8972cSEric Schweitz
34297d8972cSEric Schweitz } // namespace
34397d8972cSEric Schweitz
createFirCodeGenRewritePass()34497d8972cSEric Schweitz std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
34597d8972cSEric Schweitz return std::make_unique<CodeGenRewrite>();
34697d8972cSEric Schweitz }
347