1 //===-- PreCGRewrite.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CGOps.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/CodeGen/CodeGen.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROps.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Optimizer/Support/FIRContext.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 
23 //===----------------------------------------------------------------------===//
24 // Codegen rewrite: rewriting of subgraphs of ops
25 //===----------------------------------------------------------------------===//
26 
27 using namespace fir;
28 
29 #define DEBUG_TYPE "flang-codegen-rewrite"
30 
31 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
32                           ShapeOp shape) {
33   vec.append(shape.extents().begin(), shape.extents().end());
34 }
35 
36 // Operands of fir.shape_shift split into two vectors.
37 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
38                                   llvm::SmallVectorImpl<mlir::Value> &shiftVec,
39                                   ShapeShiftOp shift) {
40   auto endIter = shift.pairs().end();
41   for (auto i = shift.pairs().begin(); i != endIter;) {
42     shiftVec.push_back(*i++);
43     shapeVec.push_back(*i++);
44   }
45 }
46 
47 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
48                           ShiftOp shift) {
49   vec.append(shift.origins().begin(), shift.origins().end());
50 }
51 
52 namespace {
53 
54 /// Convert fir.embox to the extended form where necessary.
55 ///
56 /// The embox operation can take arguments that specify multidimensional array
57 /// properties at runtime. These properties may be shared between distinct
58 /// objects that have the same properties. Before we lower these small DAGs to
59 /// LLVM-IR, we gather all the information into a single extended operation. For
60 /// example,
61 /// ```
62 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
63 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
64 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
65 /// ```
66 /// can be rewritten as
67 /// ```
68 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) -> !fir.box<!fir.array<?xi32>>
69 /// ```
70 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> {
71 public:
72   using OpRewritePattern::OpRewritePattern;
73 
74   mlir::LogicalResult
75   matchAndRewrite(EmboxOp embox,
76                   mlir::PatternRewriter &rewriter) const override {
77     auto shapeVal = embox.getShape();
78     // If the embox does not include a shape, then do not convert it
79     if (shapeVal)
80       return rewriteDynamicShape(embox, rewriter, shapeVal);
81     if (auto boxTy = embox.getType().dyn_cast<BoxType>())
82       if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>())
83         if (seqTy.hasConstantShape())
84           return rewriteStaticShape(embox, rewriter, seqTy);
85     return mlir::failure();
86   }
87 
88   mlir::LogicalResult rewriteStaticShape(EmboxOp embox,
89                                          mlir::PatternRewriter &rewriter,
90                                          SequenceType seqTy) const {
91     auto loc = embox.getLoc();
92     llvm::SmallVector<mlir::Value> shapeOpers;
93     auto idxTy = rewriter.getIndexType();
94     for (auto ext : seqTy.getShape()) {
95       auto iAttr = rewriter.getIndexAttr(ext);
96       auto extVal = rewriter.create<mlir::ConstantOp>(loc, idxTy, iAttr);
97       shapeOpers.push_back(extVal);
98     }
99     auto xbox = rewriter.create<cg::XEmboxOp>(
100         loc, embox.getType(), embox.memref(), shapeOpers, llvm::None,
101         llvm::None, llvm::None, embox.lenParams());
102     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
103     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
104     return mlir::success();
105   }
106 
107   mlir::LogicalResult rewriteDynamicShape(EmboxOp embox,
108                                           mlir::PatternRewriter &rewriter,
109                                           mlir::Value shapeVal) const {
110     auto loc = embox.getLoc();
111     auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp());
112     llvm::SmallVector<mlir::Value> shapeOpers;
113     llvm::SmallVector<mlir::Value> shiftOpers;
114     if (shapeOp) {
115       populateShape(shapeOpers, shapeOp);
116     } else {
117       auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp());
118       assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
119       populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
120     }
121     llvm::SmallVector<mlir::Value> sliceOpers;
122     llvm::SmallVector<mlir::Value> subcompOpers;
123     if (auto s = embox.getSlice())
124       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
125         sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
126         subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
127       }
128     auto xbox = rewriter.create<cg::XEmboxOp>(
129         loc, embox.getType(), embox.memref(), shapeOpers, shiftOpers,
130         sliceOpers, subcompOpers, embox.lenParams());
131     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
132     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
133     return mlir::success();
134   }
135 };
136 
137 /// Convert fir.rebox to the extended form where necessary.
138 ///
139 /// For example,
140 /// ```
141 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<?xi32>>
142 /// ```
143 /// converted to
144 /// ```
145 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>, index, index) -> !fir.box<!fir.array<?xi32>>
146 /// ```
147 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
148 public:
149   using OpRewritePattern::OpRewritePattern;
150 
151   mlir::LogicalResult
152   matchAndRewrite(ReboxOp rebox,
153                   mlir::PatternRewriter &rewriter) const override {
154     auto loc = rebox.getLoc();
155     llvm::SmallVector<mlir::Value> shapeOpers;
156     llvm::SmallVector<mlir::Value> shiftOpers;
157     if (auto shapeVal = rebox.shape()) {
158       if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()))
159         populateShape(shapeOpers, shapeOp);
160       else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()))
161         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
162       else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp()))
163         populateShift(shiftOpers, shiftOp);
164       else
165         return mlir::failure();
166     }
167     llvm::SmallVector<mlir::Value> sliceOpers;
168     llvm::SmallVector<mlir::Value> subcompOpers;
169     if (auto s = rebox.slice())
170       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
171         sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
172         subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
173       }
174 
175     auto xRebox = rewriter.create<cg::XReboxOp>(
176         loc, rebox.getType(), rebox.box(), shapeOpers, shiftOpers, sliceOpers,
177         subcompOpers);
178     LLVM_DEBUG(llvm::dbgs()
179                << "rewriting " << rebox << " to " << xRebox << '\n');
180     rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
181     return mlir::success();
182   }
183 };
184 
185 /// Convert all fir.array_coor to the extended form.
186 ///
187 /// For example,
188 /// ```
189 ///  %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>, !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
190 /// ```
191 /// converted to
192 /// ```
193 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> : (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) -> !fir.ref<i32>
194 /// ```
195 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> {
196 public:
197   using OpRewritePattern::OpRewritePattern;
198 
199   mlir::LogicalResult
200   matchAndRewrite(ArrayCoorOp arrCoor,
201                   mlir::PatternRewriter &rewriter) const override {
202     auto loc = arrCoor.getLoc();
203     llvm::SmallVector<mlir::Value> shapeOpers;
204     llvm::SmallVector<mlir::Value> shiftOpers;
205     if (auto shapeVal = arrCoor.shape()) {
206       if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()))
207         populateShape(shapeOpers, shapeOp);
208       else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()))
209         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
210       else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp()))
211         populateShift(shiftOpers, shiftOp);
212       else
213         return mlir::failure();
214     }
215     llvm::SmallVector<mlir::Value> sliceOpers;
216     llvm::SmallVector<mlir::Value> subcompOpers;
217     if (auto s = arrCoor.slice())
218       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
219         sliceOpers.append(sliceOp.triples().begin(), sliceOp.triples().end());
220         subcompOpers.append(sliceOp.fields().begin(), sliceOp.fields().end());
221       }
222     auto xArrCoor = rewriter.create<cg::XArrayCoorOp>(
223         loc, arrCoor.getType(), arrCoor.memref(), shapeOpers, shiftOpers,
224         sliceOpers, subcompOpers, arrCoor.indices(), arrCoor.lenParams());
225     LLVM_DEBUG(llvm::dbgs()
226                << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
227     rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
228     return mlir::success();
229   }
230 };
231 
232 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> {
233 public:
234   void runOnOperation() override final {
235     auto op = getOperation();
236     auto &context = getContext();
237     mlir::OpBuilder rewriter(&context);
238     mlir::ConversionTarget target(context);
239     target.addLegalDialect<FIROpsDialect, FIRCodeGenDialect,
240                            mlir::StandardOpsDialect>();
241     target.addIllegalOp<ArrayCoorOp>();
242     target.addIllegalOp<ReboxOp>();
243     target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) {
244       return !(embox.getShape() ||
245                embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>());
246     });
247     mlir::OwningRewritePatternList patterns(&context);
248     patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
249         &context);
250     if (mlir::failed(
251             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
252       mlir::emitError(mlir::UnknownLoc::get(&context),
253                       "error in running the pre-codegen conversions");
254       signalPassFailure();
255     }
256   }
257 };
258 
259 } // namespace
260 
261 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
262   return std::make_unique<CodeGenRewrite>();
263 }
264