1 //===-- AffinePromotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This transformation is a prototype that promotes FIR loop operations
// to affine dialect operations.
11 // It is not part of the production pipeline and would need more work in order
12 // to be used in production.
13 // More information can be found in this presentation:
14 // https://slides.com/rajanwalia/deck
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "PassDetail.h"
19 #include "flang/Optimizer/Dialect/FIRDialect.h"
20 #include "flang/Optimizer/Dialect/FIROps.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/Transforms/Passes.h"
23 #include "mlir/Dialect/Affine/IR/AffineOps.h"
24 #include "mlir/Dialect/SCF/SCF.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h"
26 #include "mlir/IR/BuiltinAttributes.h"
27 #include "mlir/IR/IntegerSet.h"
28 #include "mlir/IR/Visitors.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/Optional.h"
32 #include "llvm/Support/Debug.h"
33 
34 #define DEBUG_TYPE "flang-affine-promotion"
35 
36 using namespace fir;
37 
namespace {
// Forward declarations: both analyses are defined later in this file but are
// stored by value in the maps of AffineFunctionAnalysis.
struct AffineLoopAnalysis;
struct AffineIfAnalysis;

/// Stores the analysis objects for the loops and if operations inside a
/// function. These analyses are consulted twice: first by the conversion
/// target's legality callbacks (to decide which ops must be rewritten) and
/// second by the rewrite patterns themselves.
///
/// Only the fir.do_loop ops found directly in the function body are seeded
/// by the constructor; nested loops and fir.if ops are added to the maps while
/// those loops are analyzed (see AffineLoopAnalysis).
struct AffineFunctionAnalysis {
  explicit AffineFunctionAnalysis(mlir::FuncOp funcOp) {
    // NOTE(review): the AffineLoopAnalysis constructor invoked by this
    // try_emplace may itself insert nested loops into loopAnalysisMap while
    // the outer entry is still being constructed in place. If such an
    // insertion grows the map, the bucket under construction is relocated.
    // Computing the analysis before inserting it would avoid this; TODO
    // confirm against llvm::DenseMap::try_emplace.
    for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
      loopAnalysisMap.try_emplace(op, op, *this);
  }

  /// Returns the analysis cached for `op`. If none was computed, an error is
  /// emitted and a default-constructed (non-promotable) analysis is returned.
  AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;

  /// Same as getChildLoopAnalysis, for fir.if operations.
  AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;

  /// Loop analyses keyed by the analyzed fir.do_loop operation.
  llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
  /// If analyses keyed by the analyzed fir.if operation.
  llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
};
} // namespace
59 
60 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
61   if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) {
62     if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()))
63       return true;
64     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
65                                "loop induction variable (owner not loopOp)\n";
66                op->dump());
67     return false;
68   }
69   LLVM_DEBUG(
70       llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
71                       "induction variable (not a block argument)\n";
72       op->dump(); coordinate.getDefiningOp()->dump());
73   return false;
74 }
75 
76 namespace {
77 struct AffineLoopAnalysis {
78   AffineLoopAnalysis() = default;
79 
80   explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
81       : legality(analyzeLoop(op, afa)) {}
82 
83   bool canPromoteToAffine() { return legality; }
84 
85 private:
86   bool analyzeBody(fir::DoLoopOp loopOperation,
87                    AffineFunctionAnalysis &functionAnalysis) {
88     for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
89       auto analysis = functionAnalysis.loopAnalysisMap
90                           .try_emplace(loopOp, loopOp, functionAnalysis)
91                           .first->getSecond();
92       if (!analysis.canPromoteToAffine())
93         return false;
94     }
95     for (auto ifOp : loopOperation.getOps<fir::IfOp>())
96       functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
97     return true;
98   }
99 
100   bool analyzeLoop(fir::DoLoopOp loopOperation,
101                    AffineFunctionAnalysis &functionAnalysis) {
102     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
103     return analyzeMemoryAccess(loopOperation) &&
104            analyzeBody(loopOperation, functionAnalysis);
105   }
106 
107   bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
108     if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
109       if (acoOp.memref().getType().isa<fir::BoxType>()) {
110         // TODO: Look if and how fir.box can be promoted to affine.
111         LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
112                                    "array memory operation uses fir.box\n";
113                    op->dump(); acoOp.dump(););
114         return false;
115       }
116       bool canPromote = true;
117       for (auto coordinate : acoOp.indices())
118         canPromote = canPromote && analyzeCoordinate(coordinate, op);
119       return canPromote;
120     }
121     if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
122       LLVM_DEBUG(llvm::dbgs()
123                      << "AffineLoopAnalysis: cannot promote loop, "
124                         "array memory operation uses non ArrayCoorOp\n";
125                  op->dump(); coOp.dump(););
126 
127       return false;
128     }
129     LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
130                                "reference for array load\n";
131                op->dump(););
132     return false;
133   }
134 
135   bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
136     for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
137       if (!analyzeReference(loadOp.memref(), loadOp))
138         return false;
139     for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
140       if (!analyzeReference(storeOp.memref(), storeOp))
141         return false;
142     return true;
143   }
144 
145   bool legality{};
146 };
147 } // namespace
148 
149 AffineLoopAnalysis
150 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
151   auto it = loopAnalysisMap.find_as(op);
152   if (it == loopAnalysisMap.end()) {
153     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
154                op.dump(););
155     op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
156     return {};
157   }
158   return it->getSecond();
159 }
160 
161 namespace {
162 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the
163 /// final number of symbols and dimensions of the affine map. Integer set if
164 /// possible is in Optional IntegerSet.
165 struct AffineIfCondition {
166   using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>;
167 
168   explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
169     if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>())
170       fromCmpIOp(condDef);
171   }
172 
173   bool hasIntegerSet() const { return integerSet.hasValue(); }
174 
175   mlir::IntegerSet getIntegerSet() const {
176     assert(hasIntegerSet() && "integer set is missing");
177     return integerSet.getValue();
178   }
179 
180   mlir::ValueRange getAffineArgs() const { return affineArgs; }
181 
182 private:
183   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
184                                  mlir::Value rhs) {
185     return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
186   }
187 
188   MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
189                                  MaybeAffineExpr rhs) {
190     if (lhs.hasValue() && rhs.hasValue())
191       return mlir::getAffineBinaryOpExpr(kind, lhs.getValue(), rhs.getValue());
192     return {};
193   }
194 
195   MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
196 
197   MaybeAffineExpr toAffineExpr(int64_t value) {
198     return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
199   }
200 
201   /// Returns an AffineExpr if it is a result of operations that can be done
202   /// in an affine expression, this includes -, +, *, rem, constant.
203   /// block arguments of a loopOp or forOp are used as dimensions
204   MaybeAffineExpr toAffineExpr(mlir::Value value) {
205     if (auto op = value.getDefiningOp<mlir::arith::SubIOp>())
206       return affineBinaryOp(mlir::AffineExprKind::Add, toAffineExpr(op.lhs()),
207                             affineBinaryOp(mlir::AffineExprKind::Mul,
208                                            toAffineExpr(op.rhs()),
209                                            toAffineExpr(-1)));
210     if (auto op = value.getDefiningOp<mlir::arith::AddIOp>())
211       return affineBinaryOp(mlir::AffineExprKind::Add, op.lhs(), op.rhs());
212     if (auto op = value.getDefiningOp<mlir::arith::MulIOp>())
213       return affineBinaryOp(mlir::AffineExprKind::Mul, op.lhs(), op.rhs());
214     if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>())
215       return affineBinaryOp(mlir::AffineExprKind::Mod, op.lhs(), op.rhs());
216     if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>())
217       if (auto intConstant = op.value().dyn_cast<IntegerAttr>())
218         return toAffineExpr(intConstant.getInt());
219     if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
220       affineArgs.push_back(value);
221       if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) ||
222           isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp()))
223         return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
224       return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
225     }
226     return {};
227   }
228 
229   void fromCmpIOp(mlir::arith::CmpIOp cmpOp) {
230     auto lhsAffine = toAffineExpr(cmpOp.lhs());
231     auto rhsAffine = toAffineExpr(cmpOp.rhs());
232     if (!lhsAffine.hasValue() || !rhsAffine.hasValue())
233       return;
234     auto constraintPair = constraint(
235         cmpOp.predicate(), rhsAffine.getValue() - lhsAffine.getValue());
236     if (!constraintPair)
237       return;
238     integerSet = mlir::IntegerSet::get(dimCount, symCount,
239                                        {constraintPair.getValue().first},
240                                        {constraintPair.getValue().second});
241     return;
242   }
243 
244   llvm::Optional<std::pair<AffineExpr, bool>>
245   constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) {
246     switch (predicate) {
247     case mlir::arith::CmpIPredicate::slt:
248       return {std::make_pair(basic - 1, false)};
249     case mlir::arith::CmpIPredicate::sle:
250       return {std::make_pair(basic, false)};
251     case mlir::arith::CmpIPredicate::sgt:
252       return {std::make_pair(1 - basic, false)};
253     case mlir::arith::CmpIPredicate::sge:
254       return {std::make_pair(0 - basic, false)};
255     case mlir::arith::CmpIPredicate::eq:
256       return {std::make_pair(basic, true)};
257     default:
258       return {};
259     }
260   }
261 
262   llvm::SmallVector<mlir::Value> affineArgs;
263   llvm::Optional<mlir::IntegerSet> integerSet;
264   mlir::Value firCondition;
265   unsigned symCount{0u};
266   unsigned dimCount{0u};
267 };
268 } // namespace
269 
270 namespace {
271 /// Analysis for affine promotion of fir.if
272 struct AffineIfAnalysis {
273   AffineIfAnalysis() = default;
274 
275   explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
276       : legality(analyzeIf(op, afa)) {}
277 
278   bool canPromoteToAffine() { return legality; }
279 
280 private:
281   bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
282     if (op.getNumResults() == 0)
283       return true;
284     LLVM_DEBUG(llvm::dbgs()
285                    << "AffineIfAnalysis: not promoting as op has results\n";);
286     return false;
287   }
288 
289   bool legality{};
290 };
291 } // namespace
292 
293 AffineIfAnalysis
294 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
295   auto it = ifAnalysisMap.find_as(op);
296   if (it == ifAnalysisMap.end()) {
297     LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
298                op.dump(););
299     op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
300     return {};
301   }
302   return it->getSecond();
303 }
304 
305 /// AffineMap rewriting fir.array_coor operation to affine apply,
306 /// %dim = fir.gendim %lowerBound, %upperBound, %stride
307 /// %a = fir.array_coor %arr(%dim) %i
308 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)>
309 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
310                                                  MLIRContext *context) {
311   auto index = mlir::getAffineConstantExpr(0, context);
312   auto accuExtent = mlir::getAffineConstantExpr(1, context);
313   for (unsigned i = 0; i < dimensions; ++i) {
314     mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context),
315                      lowerBound = mlir::getAffineSymbolExpr(i * 3, context),
316                      currentExtent =
317                          mlir::getAffineSymbolExpr(i * 3 + 1, context),
318                      stride = mlir::getAffineSymbolExpr(i * 3 + 2, context),
319                      currentPart = (idx * stride - lowerBound) * accuExtent;
320     index = currentPart + index;
321     accuExtent = accuExtent * currentExtent;
322   }
323   return mlir::AffineMap::get(dimensions, dimensions * 3, index);
324 }
325 
326 static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
327   if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>())
328     if (auto stepAttr = definition.value().dyn_cast<IntegerAttr>())
329       return stepAttr.getInt();
330   return {};
331 }
332 
333 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
334   if (auto refType = op.memref().getType().dyn_cast_or_null<ReferenceType>()) {
335     if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) {
336       return seqType.getEleTy();
337     }
338   }
339   op.emitError(
340       "AffineLoopConversion: array type in coordinate operation not valid\n");
341   return mlir::Type();
342 }
343 
344 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
345                               SmallVectorImpl<mlir::Value> &indexArgs,
346                               mlir::PatternRewriter &rewriter) {
347   auto one = rewriter.create<mlir::arith::ConstantOp>(
348       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
349   auto extents = shape.extents();
350   for (auto i = extents.begin(); i < extents.end(); i++) {
351     indexArgs.push_back(one);
352     indexArgs.push_back(*i);
353     indexArgs.push_back(one);
354   }
355 }
356 
357 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
358                               SmallVectorImpl<mlir::Value> &indexArgs,
359                               mlir::PatternRewriter &rewriter) {
360   auto one = rewriter.create<mlir::arith::ConstantOp>(
361       acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
362   auto extents = shape.pairs();
363   for (auto i = extents.begin(); i < extents.end();) {
364     indexArgs.push_back(*i++);
365     indexArgs.push_back(*i++);
366     indexArgs.push_back(one);
367   }
368 }
369 
370 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
371                               SmallVectorImpl<mlir::Value> &indexArgs,
372                               mlir::PatternRewriter &rewriter) {
373   auto extents = slice.triples();
374   for (auto i = extents.begin(); i < extents.end();) {
375     indexArgs.push_back(*i++);
376     indexArgs.push_back(*i++);
377     indexArgs.push_back(*i++);
378   }
379 }
380 
381 static void populateIndexArgs(fir::ArrayCoorOp acoOp,
382                               SmallVectorImpl<mlir::Value> &indexArgs,
383                               mlir::PatternRewriter &rewriter) {
384   if (auto shape = acoOp.shape().getDefiningOp<ShapeOp>())
385     return populateIndexArgs(acoOp, shape, indexArgs, rewriter);
386   if (auto shapeShift = acoOp.shape().getDefiningOp<ShapeShiftOp>())
387     return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter);
388   if (auto slice = acoOp.shape().getDefiningOp<SliceOp>())
389     return populateIndexArgs(acoOp, slice, indexArgs, rewriter);
390   return;
391 }
392 
393 /// Returns affine.apply and fir.convert from array_coor and gendims
394 static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
395 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
396   auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
397   auto affineMap =
398       createArrayIndexAffineMap(acoOp.indices().size(), acoOp.getContext());
399   SmallVector<mlir::Value> indexArgs;
400   indexArgs.append(acoOp.indices().begin(), acoOp.indices().end());
401 
402   populateIndexArgs(acoOp, indexArgs, rewriter);
403 
404   auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
405                                                           affineMap, indexArgs);
406   auto arrayElementType = coordinateArrayElement(acoOp);
407   auto newType = mlir::MemRefType::get({-1}, arrayElementType);
408   auto arrayConvert =
409       rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, acoOp.memref());
410   return std::make_pair(affineApply, arrayConvert);
411 }
412 
413 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
414   rewriter.setInsertionPoint(loadOp);
415   auto affineOps = createAffineOps(loadOp.memref(), rewriter);
416   rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(
417       loadOp, affineOps.second.getResult(), affineOps.first.getResult());
418 }
419 
420 static void rewriteStore(fir::StoreOp storeOp,
421                          mlir::PatternRewriter &rewriter) {
422   rewriter.setInsertionPoint(storeOp);
423   auto affineOps = createAffineOps(storeOp.memref(), rewriter);
424   rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.value(),
425                                                    affineOps.second.getResult(),
426                                                    affineOps.first.getResult());
427 }
428 
429 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
430   for (auto &bodyOp : block->getOperations()) {
431     if (isa<fir::LoadOp>(bodyOp))
432       rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
433     if (isa<fir::StoreOp>(bodyOp))
434       rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
435   }
436 }
437 
438 namespace {
439 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to
440 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir
441 /// loads and stores to affine.
442 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
443 public:
444   using OpRewritePattern::OpRewritePattern;
445   AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
446       : OpRewritePattern(context), functionAnalysis(afa) {}
447 
448   mlir::LogicalResult
449   matchAndRewrite(fir::DoLoopOp loop,
450                   mlir::PatternRewriter &rewriter) const override {
451     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
452                loop.dump(););
453     LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
454         functionAnalysis.getChildLoopAnalysis(loop);
455     auto &loopOps = loop.getBody()->getOperations();
456     auto loopAndIndex = createAffineFor(loop, rewriter);
457     auto affineFor = loopAndIndex.first;
458     auto inductionVar = loopAndIndex.second;
459 
460     rewriter.startRootUpdate(affineFor.getOperation());
461     affineFor.getBody()->getOperations().splice(
462         std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
463         std::prev(loopOps.end()));
464     rewriter.finalizeRootUpdate(affineFor.getOperation());
465 
466     rewriter.startRootUpdate(loop.getOperation());
467     loop.getInductionVar().replaceAllUsesWith(inductionVar);
468     rewriter.finalizeRootUpdate(loop.getOperation());
469 
470     rewriteMemoryOps(affineFor.getBody(), rewriter);
471 
472     LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
473                affineFor.dump(););
474     rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
475     return success();
476   }
477 
478 private:
479   std::pair<mlir::AffineForOp, mlir::Value>
480   createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
481     if (auto constantStep = constantIntegerLike(op.step()))
482       if (constantStep.getValue() > 0)
483         return positiveConstantStep(op, constantStep.getValue(), rewriter);
484     return genericBounds(op, rewriter);
485   }
486 
487   // when step for the loop is positive compile time constant
488   std::pair<mlir::AffineForOp, mlir::Value>
489   positiveConstantStep(fir::DoLoopOp op, int64_t step,
490                        mlir::PatternRewriter &rewriter) const {
491     auto affineFor = rewriter.create<mlir::AffineForOp>(
492         op.getLoc(), ValueRange(op.lowerBound()),
493         mlir::AffineMap::get(0, 1,
494                              mlir::getAffineSymbolExpr(0, op.getContext())),
495         ValueRange(op.upperBound()),
496         mlir::AffineMap::get(0, 1,
497                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
498         step);
499     return std::make_pair(affineFor, affineFor.getInductionVar());
500   }
501 
502   std::pair<mlir::AffineForOp, mlir::Value>
503   genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
504     auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext());
505     auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext());
506     auto step = mlir::getAffineSymbolExpr(2, op.getContext());
507     mlir::AffineMap upperBoundMap = mlir::AffineMap::get(
508         0, 3, (upperBound - lowerBound + step).floorDiv(step));
509     auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>(
510         op.getLoc(), upperBoundMap,
511         ValueRange({op.lowerBound(), op.upperBound(), op.step()}));
512     auto actualIndexMap = mlir::AffineMap::get(
513         1, 2,
514         (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) *
515             mlir::getAffineSymbolExpr(1, op.getContext()));
516 
517     auto affineFor = rewriter.create<mlir::AffineForOp>(
518         op.getLoc(), ValueRange(),
519         AffineMap::getConstantMap(0, op.getContext()),
520         genericUpperBound.getResult(),
521         mlir::AffineMap::get(0, 1,
522                              1 + mlir::getAffineSymbolExpr(0, op.getContext())),
523         1);
524     rewriter.setInsertionPointToStart(affineFor.getBody());
525     auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
526         op.getLoc(), actualIndexMap,
527         ValueRange({affineFor.getInductionVar(), op.lowerBound(), op.step()}));
528     return std::make_pair(affineFor, actualIndex.getResult());
529   }
530 
531   AffineFunctionAnalysis &functionAnalysis;
532 };
533 
534 /// Convert `fir.if` to `affine.if`.
535 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
536 public:
537   using OpRewritePattern::OpRewritePattern;
538   AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
539       : OpRewritePattern(context) {}
540   mlir::LogicalResult
541   matchAndRewrite(fir::IfOp op,
542                   mlir::PatternRewriter &rewriter) const override {
543     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
544                op.dump(););
545     auto &ifOps = op.thenRegion().front().getOperations();
546     auto affineCondition = AffineIfCondition(op.condition());
547     if (!affineCondition.hasIntegerSet()) {
548       LLVM_DEBUG(
549           llvm::dbgs()
550               << "AffineIfConversion: couldn't calculate affine condition\n";);
551       return failure();
552     }
553     auto affineIf = rewriter.create<mlir::AffineIfOp>(
554         op.getLoc(), affineCondition.getIntegerSet(),
555         affineCondition.getAffineArgs(), !op.elseRegion().empty());
556     rewriter.startRootUpdate(affineIf);
557     affineIf.getThenBlock()->getOperations().splice(
558         std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
559         std::prev(ifOps.end()));
560     if (!op.elseRegion().empty()) {
561       auto &otherOps = op.elseRegion().front().getOperations();
562       affineIf.getElseBlock()->getOperations().splice(
563           std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
564           std::prev(otherOps.end()));
565     }
566     rewriter.finalizeRootUpdate(affineIf);
567     rewriteMemoryOps(affineIf.getBody(), rewriter);
568 
569     LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
570                affineIf.dump(););
571     rewriter.replaceOp(op, affineIf.getOperation()->getResults());
572     return success();
573   }
574 };
575 
576 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
577 /// where such a promotion is possible.
578 class AffineDialectPromotion
579     : public AffineDialectPromotionBase<AffineDialectPromotion> {
580 public:
581   void runOnFunction() override {
582 
583     auto *context = &getContext();
584     auto function = getFunction();
585     markAllAnalysesPreserved();
586     auto functionAnalysis = AffineFunctionAnalysis(function);
587     mlir::OwningRewritePatternList patterns(context);
588     patterns.insert<AffineIfConversion>(context, functionAnalysis);
589     patterns.insert<AffineLoopConversion>(context, functionAnalysis);
590     mlir::ConversionTarget target = *context;
591     target.addLegalDialect<
592         mlir::AffineDialect, FIROpsDialect, mlir::scf::SCFDialect,
593         mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect>();
594     target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
595       return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
596     });
597     target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
598                                                fir::DoLoopOp op) {
599       return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
600     });
601 
602     LLVM_DEBUG(llvm::dbgs()
603                    << "AffineDialectPromotion: running promotion on: \n";
604                function.print(llvm::dbgs()););
605     // apply the patterns
606     if (mlir::failed(mlir::applyPartialConversion(function, target,
607                                                   std::move(patterns)))) {
608       mlir::emitError(mlir::UnknownLoc::get(context),
609                       "error in converting to affine dialect\n");
610       signalPassFailure();
611     }
612   }
613 };
614 } // namespace
615 
616 /// Convert FIR loop constructs to the Affine dialect
617 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
618   return std::make_unique<AffineDialectPromotion>();
619 }
620