1 //===-- PreCGRewrite.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CGOps.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/CodeGen/CodeGen.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROps.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Optimizer/Support/FIRContext.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 
24 //===----------------------------------------------------------------------===//
25 // Codegen rewrite: rewriting of subgraphs of ops
26 //===----------------------------------------------------------------------===//
27 
28 using namespace fir;
29 using namespace mlir;
30 
31 #define DEBUG_TYPE "flang-codegen-rewrite"
32 
33 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
34                           ShapeOp shape) {
35   vec.append(shape.getExtents().begin(), shape.getExtents().end());
36 }
37 
38 // Operands of fir.shape_shift split into two vectors.
39 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
40                                   llvm::SmallVectorImpl<mlir::Value> &shiftVec,
41                                   ShapeShiftOp shift) {
42   auto endIter = shift.getPairs().end();
43   for (auto i = shift.getPairs().begin(); i != endIter;) {
44     shiftVec.push_back(*i++);
45     shapeVec.push_back(*i++);
46   }
47 }
48 
49 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
50                           ShiftOp shift) {
51   vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
52 }
53 
54 namespace {
55 
56 /// Convert fir.embox to the extended form where necessary.
57 ///
58 /// The embox operation can take arguments that specify multidimensional array
59 /// properties at runtime. These properties may be shared between distinct
60 /// objects that have the same properties. Before we lower these small DAGs to
61 /// LLVM-IR, we gather all the information into a single extended operation. For
62 /// example,
63 /// ```
64 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
65 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
66 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
67 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
68 /// ```
69 /// can be rewritten as
70 /// ```
71 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
72 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
73 /// !fir.box<!fir.array<?xi32>>
74 /// ```
75 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> {
76 public:
77   using OpRewritePattern::OpRewritePattern;
78 
79   mlir::LogicalResult
80   matchAndRewrite(EmboxOp embox,
81                   mlir::PatternRewriter &rewriter) const override {
82     auto shapeVal = embox.getShape();
83     // If the embox does not include a shape, then do not convert it
84     if (shapeVal)
85       return rewriteDynamicShape(embox, rewriter, shapeVal);
86     if (auto boxTy = embox.getType().dyn_cast<BoxType>())
87       if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>())
88         if (seqTy.hasConstantShape())
89           return rewriteStaticShape(embox, rewriter, seqTy);
90     return mlir::failure();
91   }
92 
93   mlir::LogicalResult rewriteStaticShape(EmboxOp embox,
94                                          mlir::PatternRewriter &rewriter,
95                                          SequenceType seqTy) const {
96     auto loc = embox.getLoc();
97     llvm::SmallVector<mlir::Value> shapeOpers;
98     auto idxTy = rewriter.getIndexType();
99     for (auto ext : seqTy.getShape()) {
100       auto iAttr = rewriter.getIndexAttr(ext);
101       auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
102       shapeOpers.push_back(extVal);
103     }
104     auto xbox = rewriter.create<cg::XEmboxOp>(
105         loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None,
106         llvm::None, llvm::None, llvm::None, embox.getTypeparams());
107     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
108     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
109     return mlir::success();
110   }
111 
112   mlir::LogicalResult rewriteDynamicShape(EmboxOp embox,
113                                           mlir::PatternRewriter &rewriter,
114                                           mlir::Value shapeVal) const {
115     auto loc = embox.getLoc();
116     auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp());
117     llvm::SmallVector<mlir::Value> shapeOpers;
118     llvm::SmallVector<mlir::Value> shiftOpers;
119     if (shapeOp) {
120       populateShape(shapeOpers, shapeOp);
121     } else {
122       auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp());
123       assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
124       populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
125     }
126     llvm::SmallVector<mlir::Value> sliceOpers;
127     llvm::SmallVector<mlir::Value> subcompOpers;
128     llvm::SmallVector<mlir::Value> substrOpers;
129     if (auto s = embox.getSlice())
130       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
131         sliceOpers.assign(sliceOp.getTriples().begin(),
132                           sliceOp.getTriples().end());
133         subcompOpers.assign(sliceOp.getFields().begin(),
134                             sliceOp.getFields().end());
135         substrOpers.assign(sliceOp.getSubstr().begin(),
136                            sliceOp.getSubstr().end());
137       }
138     auto xbox = rewriter.create<cg::XEmboxOp>(
139         loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers,
140         sliceOpers, subcompOpers, substrOpers, embox.getTypeparams());
141     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
142     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
143     return mlir::success();
144   }
145 };
146 
147 /// Convert fir.rebox to the extended form where necessary.
148 ///
149 /// For example,
150 /// ```
151 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
152 /// !fir.box<!fir.array<?xi32>>
153 /// ```
154 /// converted to
155 /// ```
156 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
157 /// index, index) -> !fir.box<!fir.array<?xi32>>
158 /// ```
159 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
160 public:
161   using OpRewritePattern::OpRewritePattern;
162 
163   mlir::LogicalResult
164   matchAndRewrite(ReboxOp rebox,
165                   mlir::PatternRewriter &rewriter) const override {
166     auto loc = rebox.getLoc();
167     llvm::SmallVector<mlir::Value> shapeOpers;
168     llvm::SmallVector<mlir::Value> shiftOpers;
169     if (auto shapeVal = rebox.getShape()) {
170       if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()))
171         populateShape(shapeOpers, shapeOp);
172       else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()))
173         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
174       else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp()))
175         populateShift(shiftOpers, shiftOp);
176       else
177         return mlir::failure();
178     }
179     llvm::SmallVector<mlir::Value> sliceOpers;
180     llvm::SmallVector<mlir::Value> subcompOpers;
181     llvm::SmallVector<mlir::Value> substrOpers;
182     if (auto s = rebox.getSlice())
183       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
184         sliceOpers.append(sliceOp.getTriples().begin(),
185                           sliceOp.getTriples().end());
186         subcompOpers.append(sliceOp.getFields().begin(),
187                             sliceOp.getFields().end());
188         substrOpers.append(sliceOp.getSubstr().begin(),
189                            sliceOp.getSubstr().end());
190       }
191 
192     auto xRebox = rewriter.create<cg::XReboxOp>(
193         loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers,
194         sliceOpers, subcompOpers, substrOpers);
195     LLVM_DEBUG(llvm::dbgs()
196                << "rewriting " << rebox << " to " << xRebox << '\n');
197     rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
198     return mlir::success();
199   }
200 };
201 
202 /// Convert all fir.array_coor to the extended form.
203 ///
204 /// For example,
205 /// ```
206 ///  %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
207 ///  !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
208 /// ```
209 /// converted to
210 /// ```
211 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
212 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
213 /// !fir.ref<i32>
214 /// ```
215 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> {
216 public:
217   using OpRewritePattern::OpRewritePattern;
218 
219   mlir::LogicalResult
220   matchAndRewrite(ArrayCoorOp arrCoor,
221                   mlir::PatternRewriter &rewriter) const override {
222     auto loc = arrCoor.getLoc();
223     llvm::SmallVector<mlir::Value> shapeOpers;
224     llvm::SmallVector<mlir::Value> shiftOpers;
225     if (auto shapeVal = arrCoor.getShape()) {
226       if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()))
227         populateShape(shapeOpers, shapeOp);
228       else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()))
229         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
230       else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp()))
231         populateShift(shiftOpers, shiftOp);
232       else
233         return mlir::failure();
234     }
235     llvm::SmallVector<mlir::Value> sliceOpers;
236     llvm::SmallVector<mlir::Value> subcompOpers;
237     if (auto s = arrCoor.getSlice())
238       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
239         sliceOpers.append(sliceOp.getTriples().begin(),
240                           sliceOp.getTriples().end());
241         subcompOpers.append(sliceOp.getFields().begin(),
242                             sliceOp.getFields().end());
243         assert(sliceOp.getSubstr().empty() &&
244                "Don't allow substring operations on array_coor. This "
245                "restriction may be lifted in the future.");
246       }
247     auto xArrCoor = rewriter.create<cg::XArrayCoorOp>(
248         loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers,
249         sliceOpers, subcompOpers, arrCoor.getIndices(),
250         arrCoor.getTypeparams());
251     LLVM_DEBUG(llvm::dbgs()
252                << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
253     rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
254     return mlir::success();
255   }
256 };
257 
258 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> {
259 public:
260   void runOnOperation() override final {
261     auto op = getOperation();
262     auto &context = getContext();
263     mlir::OpBuilder rewriter(&context);
264     mlir::ConversionTarget target(context);
265     target.addLegalDialect<mlir::arith::ArithmeticDialect, FIROpsDialect,
266                            FIRCodeGenDialect, mlir::func::FuncDialect>();
267     target.addIllegalOp<ArrayCoorOp>();
268     target.addIllegalOp<ReboxOp>();
269     target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) {
270       return !(embox.getShape() ||
271                embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>());
272     });
273     mlir::RewritePatternSet patterns(&context);
274     patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
275         &context);
276     if (mlir::failed(
277             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
278       mlir::emitError(mlir::UnknownLoc::get(&context),
279                       "error in running the pre-codegen conversions");
280       signalPassFailure();
281     }
282   }
283 };
284 
285 } // namespace
286 
287 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
288   return std::make_unique<CodeGenRewrite>();
289 }
290