1 //===-- PreCGRewrite.cpp --------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "CGOps.h"
14 #include "PassDetail.h"
15 #include "flang/Optimizer/CodeGen/CodeGen.h"
16 #include "flang/Optimizer/Dialect/FIRDialect.h"
17 #include "flang/Optimizer/Dialect/FIROps.h"
18 #include "flang/Optimizer/Dialect/FIRType.h"
19 #include "flang/Optimizer/Support/FIRContext.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/Debug.h"
23
24 //===----------------------------------------------------------------------===//
25 // Codegen rewrite: rewriting of subgraphs of ops
26 //===----------------------------------------------------------------------===//
27
28 #define DEBUG_TYPE "flang-codegen-rewrite"
29
populateShape(llvm::SmallVectorImpl<mlir::Value> & vec,fir::ShapeOp shape)30 static void populateShape(llvm::SmallVectorImpl<mlir::Value> &vec,
31 fir::ShapeOp shape) {
32 vec.append(shape.getExtents().begin(), shape.getExtents().end());
33 }
34
35 // Operands of fir.shape_shift split into two vectors.
populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> & shapeVec,llvm::SmallVectorImpl<mlir::Value> & shiftVec,fir::ShapeShiftOp shift)36 static void populateShapeAndShift(llvm::SmallVectorImpl<mlir::Value> &shapeVec,
37 llvm::SmallVectorImpl<mlir::Value> &shiftVec,
38 fir::ShapeShiftOp shift) {
39 for (auto i = shift.getPairs().begin(), endIter = shift.getPairs().end();
40 i != endIter;) {
41 shiftVec.push_back(*i++);
42 shapeVec.push_back(*i++);
43 }
44 }
45
populateShift(llvm::SmallVectorImpl<mlir::Value> & vec,fir::ShiftOp shift)46 static void populateShift(llvm::SmallVectorImpl<mlir::Value> &vec,
47 fir::ShiftOp shift) {
48 vec.append(shift.getOrigins().begin(), shift.getOrigins().end());
49 }
50
51 namespace {
52
53 /// Convert fir.embox to the extended form where necessary.
54 ///
55 /// The embox operation can take arguments that specify multidimensional array
56 /// properties at runtime. These properties may be shared between distinct
57 /// objects that have the same properties. Before we lower these small DAGs to
58 /// LLVM-IR, we gather all the information into a single extended operation. For
59 /// example,
60 /// ```
61 /// %1 = fir.shape_shift %4, %5 : (index, index) -> !fir.shapeshift<1>
62 /// %2 = fir.slice %6, %7, %8 : (index, index, index) -> !fir.slice<1>
63 /// %3 = fir.embox %0 (%1) [%2] : (!fir.ref<!fir.array<?xi32>>,
64 /// !fir.shapeshift<1>, !fir.slice<1>) -> !fir.box<!fir.array<?xi32>>
65 /// ```
66 /// can be rewritten as
67 /// ```
68 /// %1 = fircg.ext_embox %0(%5) origin %4[%6, %7, %8] :
69 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index) ->
70 /// !fir.box<!fir.array<?xi32>>
71 /// ```
72 class EmboxConversion : public mlir::OpRewritePattern<fir::EmboxOp> {
73 public:
74 using OpRewritePattern::OpRewritePattern;
75
76 mlir::LogicalResult
matchAndRewrite(fir::EmboxOp embox,mlir::PatternRewriter & rewriter) const77 matchAndRewrite(fir::EmboxOp embox,
78 mlir::PatternRewriter &rewriter) const override {
79 // If the embox does not include a shape, then do not convert it
80 if (auto shapeVal = embox.getShape())
81 return rewriteDynamicShape(embox, rewriter, shapeVal);
82 if (auto boxTy = embox.getType().dyn_cast<fir::BoxType>())
83 if (auto seqTy = boxTy.getEleTy().dyn_cast<fir::SequenceType>())
84 if (seqTy.hasConstantShape())
85 return rewriteStaticShape(embox, rewriter, seqTy);
86 return mlir::failure();
87 }
88
rewriteStaticShape(fir::EmboxOp embox,mlir::PatternRewriter & rewriter,fir::SequenceType seqTy) const89 mlir::LogicalResult rewriteStaticShape(fir::EmboxOp embox,
90 mlir::PatternRewriter &rewriter,
91 fir::SequenceType seqTy) const {
92 auto loc = embox.getLoc();
93 llvm::SmallVector<mlir::Value> shapeOpers;
94 auto idxTy = rewriter.getIndexType();
95 for (auto ext : seqTy.getShape()) {
96 auto iAttr = rewriter.getIndexAttr(ext);
97 auto extVal = rewriter.create<mlir::arith::ConstantOp>(loc, idxTy, iAttr);
98 shapeOpers.push_back(extVal);
99 }
100 auto xbox = rewriter.create<fir::cg::XEmboxOp>(
101 loc, embox.getType(), embox.getMemref(), shapeOpers, llvm::None,
102 llvm::None, llvm::None, llvm::None, embox.getTypeparams());
103 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
104 rewriter.replaceOp(embox, xbox.getOperation()->getResults());
105 return mlir::success();
106 }
107
rewriteDynamicShape(fir::EmboxOp embox,mlir::PatternRewriter & rewriter,mlir::Value shapeVal) const108 mlir::LogicalResult rewriteDynamicShape(fir::EmboxOp embox,
109 mlir::PatternRewriter &rewriter,
110 mlir::Value shapeVal) const {
111 auto loc = embox.getLoc();
112 llvm::SmallVector<mlir::Value> shapeOpers;
113 llvm::SmallVector<mlir::Value> shiftOpers;
114 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp())) {
115 populateShape(shapeOpers, shapeOp);
116 } else {
117 auto shiftOp =
118 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp());
119 assert(shiftOp && "shape is neither fir.shape nor fir.shape_shift");
120 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
121 }
122 llvm::SmallVector<mlir::Value> sliceOpers;
123 llvm::SmallVector<mlir::Value> subcompOpers;
124 llvm::SmallVector<mlir::Value> substrOpers;
125 if (auto s = embox.getSlice())
126 if (auto sliceOp =
127 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
128 sliceOpers.assign(sliceOp.getTriples().begin(),
129 sliceOp.getTriples().end());
130 subcompOpers.assign(sliceOp.getFields().begin(),
131 sliceOp.getFields().end());
132 substrOpers.assign(sliceOp.getSubstr().begin(),
133 sliceOp.getSubstr().end());
134 }
135 auto xbox = rewriter.create<fir::cg::XEmboxOp>(
136 loc, embox.getType(), embox.getMemref(), shapeOpers, shiftOpers,
137 sliceOpers, subcompOpers, substrOpers, embox.getTypeparams());
138 LLVM_DEBUG(llvm::dbgs() << "rewriting " << embox << " to " << xbox << '\n');
139 rewriter.replaceOp(embox, xbox.getOperation()->getResults());
140 return mlir::success();
141 }
142 };
143
144 /// Convert fir.rebox to the extended form where necessary.
145 ///
146 /// For example,
147 /// ```
148 /// %5 = fir.rebox %3(%1) : (!fir.box<!fir.array<?xi32>>, !fir.shapeshift<1>) ->
149 /// !fir.box<!fir.array<?xi32>>
150 /// ```
151 /// converted to
152 /// ```
153 /// %5 = fircg.ext_rebox %3(%13) origin %12 : (!fir.box<!fir.array<?xi32>>,
154 /// index, index) -> !fir.box<!fir.array<?xi32>>
155 /// ```
156 class ReboxConversion : public mlir::OpRewritePattern<fir::ReboxOp> {
157 public:
158 using OpRewritePattern::OpRewritePattern;
159
160 mlir::LogicalResult
matchAndRewrite(fir::ReboxOp rebox,mlir::PatternRewriter & rewriter) const161 matchAndRewrite(fir::ReboxOp rebox,
162 mlir::PatternRewriter &rewriter) const override {
163 auto loc = rebox.getLoc();
164 llvm::SmallVector<mlir::Value> shapeOpers;
165 llvm::SmallVector<mlir::Value> shiftOpers;
166 if (auto shapeVal = rebox.getShape()) {
167 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
168 populateShape(shapeOpers, shapeOp);
169 else if (auto shiftOp =
170 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
171 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
172 else if (auto shiftOp =
173 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
174 populateShift(shiftOpers, shiftOp);
175 else
176 return mlir::failure();
177 }
178 llvm::SmallVector<mlir::Value> sliceOpers;
179 llvm::SmallVector<mlir::Value> subcompOpers;
180 llvm::SmallVector<mlir::Value> substrOpers;
181 if (auto s = rebox.getSlice())
182 if (auto sliceOp =
183 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
184 sliceOpers.append(sliceOp.getTriples().begin(),
185 sliceOp.getTriples().end());
186 subcompOpers.append(sliceOp.getFields().begin(),
187 sliceOp.getFields().end());
188 substrOpers.append(sliceOp.getSubstr().begin(),
189 sliceOp.getSubstr().end());
190 }
191
192 auto xRebox = rewriter.create<fir::cg::XReboxOp>(
193 loc, rebox.getType(), rebox.getBox(), shapeOpers, shiftOpers,
194 sliceOpers, subcompOpers, substrOpers);
195 LLVM_DEBUG(llvm::dbgs()
196 << "rewriting " << rebox << " to " << xRebox << '\n');
197 rewriter.replaceOp(rebox, xRebox.getOperation()->getResults());
198 return mlir::success();
199 }
200 };
201
202 /// Convert all fir.array_coor to the extended form.
203 ///
204 /// For example,
205 /// ```
206 /// %4 = fir.array_coor %addr (%1) [%2] %0 : (!fir.ref<!fir.array<?xi32>>,
207 /// !fir.shapeshift<1>, !fir.slice<1>, index) -> !fir.ref<i32>
208 /// ```
209 /// converted to
210 /// ```
211 /// %40 = fircg.ext_array_coor %addr(%9) origin %8[%4, %5, %6<%39> :
212 /// (!fir.ref<!fir.array<?xi32>>, index, index, index, index, index, index) ->
213 /// !fir.ref<i32>
214 /// ```
215 class ArrayCoorConversion : public mlir::OpRewritePattern<fir::ArrayCoorOp> {
216 public:
217 using OpRewritePattern::OpRewritePattern;
218
219 mlir::LogicalResult
matchAndRewrite(fir::ArrayCoorOp arrCoor,mlir::PatternRewriter & rewriter) const220 matchAndRewrite(fir::ArrayCoorOp arrCoor,
221 mlir::PatternRewriter &rewriter) const override {
222 auto loc = arrCoor.getLoc();
223 llvm::SmallVector<mlir::Value> shapeOpers;
224 llvm::SmallVector<mlir::Value> shiftOpers;
225 if (auto shapeVal = arrCoor.getShape()) {
226 if (auto shapeOp = mlir::dyn_cast<fir::ShapeOp>(shapeVal.getDefiningOp()))
227 populateShape(shapeOpers, shapeOp);
228 else if (auto shiftOp =
229 mlir::dyn_cast<fir::ShapeShiftOp>(shapeVal.getDefiningOp()))
230 populateShapeAndShift(shapeOpers, shiftOpers, shiftOp);
231 else if (auto shiftOp =
232 mlir::dyn_cast<fir::ShiftOp>(shapeVal.getDefiningOp()))
233 populateShift(shiftOpers, shiftOp);
234 else
235 return mlir::failure();
236 }
237 llvm::SmallVector<mlir::Value> sliceOpers;
238 llvm::SmallVector<mlir::Value> subcompOpers;
239 if (auto s = arrCoor.getSlice())
240 if (auto sliceOp =
241 mlir::dyn_cast_or_null<fir::SliceOp>(s.getDefiningOp())) {
242 sliceOpers.append(sliceOp.getTriples().begin(),
243 sliceOp.getTriples().end());
244 subcompOpers.append(sliceOp.getFields().begin(),
245 sliceOp.getFields().end());
246 assert(sliceOp.getSubstr().empty() &&
247 "Don't allow substring operations on array_coor. This "
248 "restriction may be lifted in the future.");
249 }
250 auto xArrCoor = rewriter.create<fir::cg::XArrayCoorOp>(
251 loc, arrCoor.getType(), arrCoor.getMemref(), shapeOpers, shiftOpers,
252 sliceOpers, subcompOpers, arrCoor.getIndices(),
253 arrCoor.getTypeparams());
254 LLVM_DEBUG(llvm::dbgs()
255 << "rewriting " << arrCoor << " to " << xArrCoor << '\n');
256 rewriter.replaceOp(arrCoor, xArrCoor.getOperation()->getResults());
257 return mlir::success();
258 }
259 };
260
261 class CodeGenRewrite : public fir::CodeGenRewriteBase<CodeGenRewrite> {
262 public:
runOn(mlir::Operation * op,mlir::Region & region)263 void runOn(mlir::Operation *op, mlir::Region ®ion) {
264 auto &context = getContext();
265 mlir::OpBuilder rewriter(&context);
266 mlir::ConversionTarget target(context);
267 target.addLegalDialect<mlir::arith::ArithmeticDialect, fir::FIROpsDialect,
268 fir::FIRCodeGenDialect, mlir::func::FuncDialect>();
269 target.addIllegalOp<fir::ArrayCoorOp>();
270 target.addIllegalOp<fir::ReboxOp>();
271 target.addDynamicallyLegalOp<fir::EmboxOp>([](fir::EmboxOp embox) {
272 return !(embox.getShape() || embox.getType()
273 .cast<fir::BoxType>()
274 .getEleTy()
275 .isa<fir::SequenceType>());
276 });
277 mlir::RewritePatternSet patterns(&context);
278 patterns.insert<EmboxConversion, ArrayCoorConversion, ReboxConversion>(
279 &context);
280 if (mlir::failed(
281 mlir::applyPartialConversion(op, target, std::move(patterns)))) {
282 mlir::emitError(mlir::UnknownLoc::get(&context),
283 "error in running the pre-codegen conversions");
284 signalPassFailure();
285 }
286 // Erase any residual.
287 simplifyRegion(region);
288 }
289
runOnOperation()290 void runOnOperation() override final {
291 // Call runOn on all top level regions that may contain emboxOp/arrayCoorOp.
292 auto mod = getOperation();
293 for (auto func : mod.getOps<mlir::func::FuncOp>())
294 runOn(func, func.getBody());
295 for (auto global : mod.getOps<fir::GlobalOp>())
296 runOn(global, global.getRegion());
297 }
298
299 // Clean up the region.
simplifyRegion(mlir::Region & region)300 void simplifyRegion(mlir::Region ®ion) {
301 for (auto &block : region.getBlocks())
302 for (auto &op : block.getOperations()) {
303 for (auto ® : op.getRegions())
304 simplifyRegion(reg);
305 maybeEraseOp(&op);
306 }
307 doDCE();
308 }
309
310 /// Run a simple DCE cleanup to remove any dead code after the rewrites.
doDCE()311 void doDCE() {
312 std::vector<mlir::Operation *> workList;
313 workList.swap(opsToErase);
314 while (!workList.empty()) {
315 for (auto *op : workList) {
316 std::vector<mlir::Value> opOperands(op->operand_begin(),
317 op->operand_end());
318 LLVM_DEBUG(llvm::dbgs() << "DCE on " << *op << '\n');
319 ++numDCE;
320 op->erase();
321 for (auto opnd : opOperands)
322 maybeEraseOp(opnd.getDefiningOp());
323 }
324 workList.clear();
325 workList.swap(opsToErase);
326 }
327 }
328
maybeEraseOp(mlir::Operation * op)329 void maybeEraseOp(mlir::Operation *op) {
330 if (!op)
331 return;
332 if (op->hasTrait<mlir::OpTrait::IsTerminator>())
333 return;
334 if (mlir::isOpTriviallyDead(op))
335 opsToErase.push_back(op);
336 }
337
338 private:
339 std::vector<mlir::Operation *> opsToErase;
340 };
341
342 } // namespace
343
createFirCodeGenRewritePass()344 std::unique_ptr<mlir::Pass> fir::createFirCodeGenRewritePass() {
345 return std::make_unique<CodeGenRewrite>();
346 }
347