1 //===-- PreCGRewrite.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CGOps.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/CodeGen/CodeGen.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROps.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Optimizer/Support/FIRContext.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23 
24 //===----------------------------------------------------------------------===//
25 // Codegen rewrite: rewriting of subgraphs of ops
26 //===----------------------------------------------------------------------===//
27 
28 using namespace fir;
29 
30 #define DEBUG_TYPE "flang-codegen-rewrite"
31 
32 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
33                           ShapeOp shape) {
34   vec.append(shape.getExtents().begin(), shape.getExtents().end());
35 }
36 
37 // Operands of fir.shape_shift split into two vectors.
38 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
39                                   llvm::SmallVectorImpl<mlir::Value> &shiftVec,
40                                   ShapeShiftOp shift) {
41   auto endIter = shift.getPairs().end();
42   for (auto i = shift.getPairs().begin(); i != endIter;) {
43     shiftVec.push_back(*i++);
44     shapeVec.push_back(*i++);
45   }
46 }
47 
48 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
49                           ShiftOp shift) {
50   vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
51 }
52 
53 namespace {
54 
55 /// Convert fir.embox to the extended form where necessary.
56 ///
57 /// The embox operation can take arguments that specify multidimensional array
58 /// properties at runtime. These properties may be shared between distinct
59 /// objects that have the same properties. Before we lower these small DAGs to
60 /// LLVM-IR, we gather all the information into a single extended operation. For
61 /// example,
62 /// ```
63 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
64 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
65 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
66 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
67 /// ```
68 /// can be rewritten as
69 /// ```
70 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
71 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
72 /// !fir.box<!fir.array<?xi32>>
73 /// ```
74 class EmboxConversion : public mlir::OpRewritePattern<EmboxOp> {
75 public:
76   using OpRewritePattern::OpRewritePattern;
77 
78   mlir::LogicalResult
79   matchAndRewrite(EmboxOp embox,
80                   mlir::PatternRewriter &rewriter) const override {
81     auto shapeVal = embox.getShape();
82     // If the embox does not include a shape, then do not convert it
83     if (shapeVal)
84       return rewriteDynamicShape(embox, rewriter, shapeVal);
85     if (auto boxTy = embox.getType().dyn_cast<BoxType>())
86       if (auto seqTy = boxTy.getEleTy().dyn_cast<SequenceType>())
87         if (seqTy.hasConstantShape())
88           return rewriteStaticShape(embox, rewriter, seqTy);
89     return mlir::failure();
90   }
91 
92   mlir::LogicalResult rewriteStaticShape(EmboxOp embox,
93                                          mlir::PatternRewriter &rewriter,
94                                          SequenceType seqTy) const {
95     auto loc = embox.getLoc();
96     llvm::SmallVector<mlir::Value> shapeOpers;
97     auto idxTy = rewriter.getIndexType();
98     for (auto ext : seqTy.getShape()) {
99       auto iAttr = rewriter.getIndexAttr(ext);
100       auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
101       shapeOpers.push_back(extVal);
102     }
103     auto xbox = rewriter.create<cg::XEmboxOp>(
104         loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None,
105         llvm::None, llvm::None, llvm::None, embox.getTypeparams());
106     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
107     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
108     return mlir::success();
109   }
110 
111   mlir::LogicalResult rewriteDynamicShape(EmboxOp embox,
112                                           mlir::PatternRewriter &rewriter,
113                                           mlir::Value shapeVal) const {
114     auto loc = embox.getLoc();
115     auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp());
116     llvm::SmallVector<mlir::Value> shapeOpers;
117     llvm::SmallVector<mlir::Value> shiftOpers;
118     if (shapeOp) {
119       populateShape(shapeOpers, shapeOp);
120     } else {
121       auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp());
122       assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
123       populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
124     }
125     llvm::SmallVector<mlir::Value> sliceOpers;
126     llvm::SmallVector<mlir::Value> subcompOpers;
127     llvm::SmallVector<mlir::Value> substrOpers;
128     if (auto s = embox.getSlice())
129       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
130         sliceOpers.assign(sliceOp.getTriples().begin(),
131                           sliceOp.getTriples().end());
132         subcompOpers.assign(sliceOp.getFields().begin(),
133                             sliceOp.getFields().end());
134         substrOpers.assign(sliceOp.getSubstr().begin(),
135                            sliceOp.getSubstr().end());
136       }
137     auto xbox = rewriter.create<cg::XEmboxOp>(
138         loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers,
139         sliceOpers, subcompOpers, substrOpers, embox.getTypeparams());
140     LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
141     rewriter.replaceOp(embox, xbox.getOperation()->getResults());
142     return mlir::success();
143   }
144 };
145 
146 /// Convert fir.rebox to the extended form where necessary.
147 ///
148 /// For example,
149 /// ```
150 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
151 /// !fir.box<!fir.array<?xi32>>
152 /// ```
153 /// converted to
154 /// ```
155 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
156 /// index, index) -> !fir.box<!fir.array<?xi32>>
157 /// ```
158 class ReboxConversion : public mlir::OpRewritePattern<ReboxOp> {
159 public:
160   using OpRewritePattern::OpRewritePattern;
161 
162   mlir::LogicalResult
163   matchAndRewrite(ReboxOp rebox,
164                   mlir::PatternRewriter &rewriter) const override {
165     auto loc = rebox.getLoc();
166     llvm::SmallVector<mlir::Value> shapeOpers;
167     llvm::SmallVector<mlir::Value> shiftOpers;
168     if (auto shapeVal = rebox.getShape()) {
169       if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()))
170         populateShape(shapeOpers, shapeOp);
171       else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()))
172         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
173       else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp()))
174         populateShift(shiftOpers, shiftOp);
175       else
176         return mlir::failure();
177     }
178     llvm::SmallVector<mlir::Value> sliceOpers;
179     llvm::SmallVector<mlir::Value> subcompOpers;
180     llvm::SmallVector<mlir::Value> substrOpers;
181     if (auto s = rebox.getSlice())
182       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
183         sliceOpers.append(sliceOp.getTriples().begin(),
184                           sliceOp.getTriples().end());
185         subcompOpers.append(sliceOp.getFields().begin(),
186                             sliceOp.getFields().end());
187         substrOpers.append(sliceOp.getSubstr().begin(),
188                            sliceOp.getSubstr().end());
189       }
190 
191     auto xRebox = rewriter.create<cg::XReboxOp>(
192         loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers,
193         sliceOpers, subcompOpers, substrOpers);
194     LLVM_DEBUG(llvm::dbgs()
195                << "rewriting " << rebox << " to " << xRebox << '\n');
196     rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
197     return mlir::success();
198   }
199 };
200 
201 /// Convert all fir.array_coor to the extended form.
202 ///
203 /// For example,
204 /// ```
205 ///  %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
206 ///  !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
207 /// ```
208 /// converted to
209 /// ```
210 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
211 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
212 /// !fir.ref<i32>
213 /// ```
214 class ArrayCoorConversion : public mlir::OpRewritePattern<ArrayCoorOp> {
215 public:
216   using OpRewritePattern::OpRewritePattern;
217 
218   mlir::LogicalResult
219   matchAndRewrite(ArrayCoorOp arrCoor,
220                   mlir::PatternRewriter &rewriter) const override {
221     auto loc = arrCoor.getLoc();
222     llvm::SmallVector<mlir::Value> shapeOpers;
223     llvm::SmallVector<mlir::Value> shiftOpers;
224     if (auto shapeVal = arrCoor.getShape()) {
225       if (auto shapeOp = dyn_cast<ShapeOp>(shapeVal.getDefiningOp()))
226         populateShape(shapeOpers, shapeOp);
227       else if (auto shiftOp = dyn_cast<ShapeShiftOp>(shapeVal.getDefiningOp()))
228         populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
229       else if (auto shiftOp = dyn_cast<ShiftOp>(shapeVal.getDefiningOp()))
230         populateShift(shiftOpers, shiftOp);
231       else
232         return mlir::failure();
233     }
234     llvm::SmallVector<mlir::Value> sliceOpers;
235     llvm::SmallVector<mlir::Value> subcompOpers;
236     if (auto s = arrCoor.getSlice())
237       if (auto sliceOp = dyn_cast_or_null<SliceOp>(s.getDefiningOp())) {
238         sliceOpers.append(sliceOp.getTriples().begin(),
239                           sliceOp.getTriples().end());
240         subcompOpers.append(sliceOp.getFields().begin(),
241                             sliceOp.getFields().end());
242         assert(sliceOp.getSubstr().empty() &&
243                "Don't allow substring operations on array_coor. This "
244                "restriction may be lifted in the future.");
245       }
246     auto xArrCoor = rewriter.create<cg::XArrayCoorOp>(
247         loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers,
248         sliceOpers, subcompOpers, arrCoor.getIndices(),
249         arrCoor.getTypeparams());
250     LLVM_DEBUG(llvm::dbgs()
251                << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
252     rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
253     return mlir::success();
254   }
255 };
256 
257 class CodeGenRewrite : public CodeGenRewriteBase<CodeGenRewrite> {
258 public:
259   void runOnOperation() override final {
260     auto op = getOperation();
261     auto &context = getContext();
262     mlir::OpBuilder rewriter(&context);
263     mlir::ConversionTarget target(context);
264     target.addLegalDialect<mlir::arith::ArithmeticDialect, FIROpsDialect,
265                            FIRCodeGenDialect, mlir::func::FuncDialect>();
266     target.addIllegalOp<ArrayCoorOp>();
267     target.addIllegalOp<ReboxOp>();
268     target.addDynamicallyLegalOp<EmboxOp>([](EmboxOp embox) {
269       return !(embox.getShape() ||
270                embox.getType().cast<BoxType>().getEleTy().isa<SequenceType>());
271     });
272     mlir::RewritePatternSet patterns(&context);
273     patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
274         &context);
275     if (mlir::failed(
276             mlir::applyPartialConversion(op, target, std::move(patterns)))) {
277       mlir::emitError(mlir::UnknownLoc::get(&context),
278                       "error in running the pre-codegen conversions");
279       signalPassFailure();
280     }
281   }
282 };
283 
284 } // namespace
285 
286 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
287   return std::make_unique<CodeGenRewrite>();
288 }
289