1 //===-- AffinePromotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This transformation is a prototype that promotes FIR loop operations
// to affine dialect operations.
11 // It is not part of the production pipeline and would need more work in order
12 // to be used in production.
13 // More information can be found in this presentation:
14 // https://slides.com/rajanwalia/deck
15 //
16 //===----------------------------------------------------------------------===//
17
18 #include "PassDetail.h"
19 #include "flang/Optimizer/Dialect/FIRDialect.h"
20 #include "flang/Optimizer/Dialect/FIROps.h"
21 #include "flang/Optimizer/Dialect/FIRType.h"
22 #include "flang/Optimizer/Transforms/Passes.h"
23 #include "mlir/Dialect/Affine/IR/AffineOps.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"
25 #include "mlir/Dialect/SCF/IR/SCF.h"
26 #include "mlir/IR/BuiltinAttributes.h"
27 #include "mlir/IR/IntegerSet.h"
28 #include "mlir/IR/Visitors.h"
29 #include "mlir/Transforms/DialectConversion.h"
30 #include "llvm/ADT/DenseMap.h"
31 #include "llvm/ADT/Optional.h"
32 #include "llvm/Support/Debug.h"
33
34 #define DEBUG_TYPE "flang-affine-promotion"
35
36 using namespace fir;
37 using namespace mlir;
38
39 namespace {
40 struct AffineLoopAnalysis;
41 struct AffineIfAnalysis;
42
43 /// Stores analysis objects for all loops and if operations inside a function
44 /// these analysis are used twice, first for marking operations for rewrite and
45 /// second when doing rewrite.
46 struct AffineFunctionAnalysis {
AffineFunctionAnalysis__anonf1cdc1ed0111::AffineFunctionAnalysis47 explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) {
48 for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
49 loopAnalysisMap.try_emplace(op, op, *this);
50 }
51
52 AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
53
54 AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;
55
56 llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
57 llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
58 };
59 } // namespace
60
analyzeCoordinate(mlir::Value coordinate,mlir::Operation * op)61 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
62 if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) {
63 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()))
64 return true;
65 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
66 "loop induction variable (owner not loopOp)\n";
67 op->dump());
68 return false;
69 }
70 LLVM_DEBUG(
71 llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
72 "induction variable (not a block argument)\n";
73 op->dump(); coordinate.getDefiningOp()->dump());
74 return false;
75 }
76
77 namespace {
78 struct AffineLoopAnalysis {
79 AffineLoopAnalysis() = default;
80
AffineLoopAnalysis__anonf1cdc1ed0211::AffineLoopAnalysis81 explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
82 : legality(analyzeLoop(op, afa)) {}
83
canPromoteToAffine__anonf1cdc1ed0211::AffineLoopAnalysis84 bool canPromoteToAffine() { return legality; }
85
86 private:
analyzeBody__anonf1cdc1ed0211::AffineLoopAnalysis87 bool analyzeBody(fir::DoLoopOp loopOperation,
88 AffineFunctionAnalysis &functionAnalysis) {
89 for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
90 auto analysis = functionAnalysis.loopAnalysisMap
91 .try_emplace(loopOp, loopOp, functionAnalysis)
92 .first->getSecond();
93 if (!analysis.canPromoteToAffine())
94 return false;
95 }
96 for (auto ifOp : loopOperation.getOps<fir::IfOp>())
97 functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
98 return true;
99 }
100
analyzeLoop__anonf1cdc1ed0211::AffineLoopAnalysis101 bool analyzeLoop(fir::DoLoopOp loopOperation,
102 AffineFunctionAnalysis &functionAnalysis) {
103 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
104 return analyzeMemoryAccess(loopOperation) &&
105 analyzeBody(loopOperation, functionAnalysis);
106 }
107
analyzeReference__anonf1cdc1ed0211::AffineLoopAnalysis108 bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
109 if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
110 if (acoOp.getMemref().getType().isa<fir::BoxType>()) {
111 // TODO: Look if and how fir.box can be promoted to affine.
112 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
113 "array memory operation uses fir.box\n";
114 op->dump(); acoOp.dump(););
115 return false;
116 }
117 bool canPromote = true;
118 for (auto coordinate : acoOp.getIndices())
119 canPromote = canPromote && analyzeCoordinate(coordinate, op);
120 return canPromote;
121 }
122 if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
123 LLVM_DEBUG(llvm::dbgs()
124 << "AffineLoopAnalysis: cannot promote loop, "
125 "array memory operation uses non ArrayCoorOp\n";
126 op->dump(); coOp.dump(););
127
128 return false;
129 }
130 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
131 "reference for array load\n";
132 op->dump(););
133 return false;
134 }
135
analyzeMemoryAccess__anonf1cdc1ed0211::AffineLoopAnalysis136 bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
137 for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
138 if (!analyzeReference(loadOp.getMemref(), loadOp))
139 return false;
140 for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
141 if (!analyzeReference(storeOp.getMemref(), storeOp))
142 return false;
143 return true;
144 }
145
146 bool legality{};
147 };
148 } // namespace
149
150 AffineLoopAnalysis
getChildLoopAnalysis(fir::DoLoopOp op) const151 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
152 auto it = loopAnalysisMap.find_as(op);
153 if (it == loopAnalysisMap.end()) {
154 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
155 op.dump(););
156 op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
157 return {};
158 }
159 return it->getSecond();
160 }
161
162 namespace {
163 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the
164 /// final number of symbols and dimensions of the affine map. Integer set if
165 /// possible is in Optional IntegerSet.
166 struct AffineIfCondition {
167 using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>;
168
AffineIfCondition__anonf1cdc1ed0311::AffineIfCondition169 explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
170 if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>())
171 fromCmpIOp(condDef);
172 }
173
hasIntegerSet__anonf1cdc1ed0311::AffineIfCondition174 bool hasIntegerSet() const { return integerSet.has_value(); }
175
getIntegerSet__anonf1cdc1ed0311::AffineIfCondition176 mlir::IntegerSet getIntegerSet() const {
177 assert(hasIntegerSet() && "integer set is missing");
178 return integerSet.value();
179 }
180
getAffineArgs__anonf1cdc1ed0311::AffineIfCondition181 mlir::ValueRange getAffineArgs() const { return affineArgs; }
182
183 private:
affineBinaryOp__anonf1cdc1ed0311::AffineIfCondition184 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
185 mlir::Value rhs) {
186 return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
187 }
188
affineBinaryOp__anonf1cdc1ed0311::AffineIfCondition189 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
190 MaybeAffineExpr rhs) {
191 if (lhs && rhs)
192 return mlir::getAffineBinaryOpExpr(kind, *lhs, *rhs);
193 return {};
194 }
195
toAffineExpr__anonf1cdc1ed0311::AffineIfCondition196 MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
197
toAffineExpr__anonf1cdc1ed0311::AffineIfCondition198 MaybeAffineExpr toAffineExpr(int64_t value) {
199 return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
200 }
201
202 /// Returns an AffineExpr if it is a result of operations that can be done
203 /// in an affine expression, this includes -, +, *, rem, constant.
204 /// block arguments of a loopOp or forOp are used as dimensions
toAffineExpr__anonf1cdc1ed0311::AffineIfCondition205 MaybeAffineExpr toAffineExpr(mlir::Value value) {
206 if (auto op = value.getDefiningOp<mlir::arith::SubIOp>())
207 return affineBinaryOp(
208 mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()),
209 affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()),
210 toAffineExpr(-1)));
211 if (auto op = value.getDefiningOp<mlir::arith::AddIOp>())
212 return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(),
213 op.getRhs());
214 if (auto op = value.getDefiningOp<mlir::arith::MulIOp>())
215 return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(),
216 op.getRhs());
217 if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>())
218 return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(),
219 op.getRhs());
220 if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>())
221 if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>())
222 return toAffineExpr(intConstant.getInt());
223 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
224 affineArgs.push_back(value);
225 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) ||
226 isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp()))
227 return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
228 return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
229 }
230 return {};
231 }
232
fromCmpIOp__anonf1cdc1ed0311::AffineIfCondition233 void fromCmpIOp(mlir::arith::CmpIOp cmpOp) {
234 auto lhsAffine = toAffineExpr(cmpOp.getLhs());
235 auto rhsAffine = toAffineExpr(cmpOp.getRhs());
236 if (!lhsAffine || !rhsAffine)
237 return;
238 auto constraintPair =
239 constraint(cmpOp.getPredicate(), *rhsAffine - *lhsAffine);
240 if (!constraintPair)
241 return;
242 integerSet = mlir::IntegerSet::get(
243 dimCount, symCount, {constraintPair->first}, {constraintPair->second});
244 }
245
246 llvm::Optional<std::pair<AffineExpr, bool>>
constraint__anonf1cdc1ed0311::AffineIfCondition247 constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) {
248 switch (predicate) {
249 case mlir::arith::CmpIPredicate::slt:
250 return {std::make_pair(basic - 1, false)};
251 case mlir::arith::CmpIPredicate::sle:
252 return {std::make_pair(basic, false)};
253 case mlir::arith::CmpIPredicate::sgt:
254 return {std::make_pair(1 - basic, false)};
255 case mlir::arith::CmpIPredicate::sge:
256 return {std::make_pair(0 - basic, false)};
257 case mlir::arith::CmpIPredicate::eq:
258 return {std::make_pair(basic, true)};
259 default:
260 return {};
261 }
262 }
263
264 llvm::SmallVector<mlir::Value> affineArgs;
265 llvm::Optional<mlir::IntegerSet> integerSet;
266 mlir::Value firCondition;
267 unsigned symCount{0u};
268 unsigned dimCount{0u};
269 };
270 } // namespace
271
272 namespace {
273 /// Analysis for affine promotion of fir.if
274 struct AffineIfAnalysis {
275 AffineIfAnalysis() = default;
276
AffineIfAnalysis__anonf1cdc1ed0411::AffineIfAnalysis277 explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
278 : legality(analyzeIf(op, afa)) {}
279
canPromoteToAffine__anonf1cdc1ed0411::AffineIfAnalysis280 bool canPromoteToAffine() { return legality; }
281
282 private:
analyzeIf__anonf1cdc1ed0411::AffineIfAnalysis283 bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
284 if (op.getNumResults() == 0)
285 return true;
286 LLVM_DEBUG(llvm::dbgs()
287 << "AffineIfAnalysis: not promoting as op has results\n";);
288 return false;
289 }
290
291 bool legality{};
292 };
293 } // namespace
294
295 AffineIfAnalysis
getChildIfAnalysis(fir::IfOp op) const296 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
297 auto it = ifAnalysisMap.find_as(op);
298 if (it == ifAnalysisMap.end()) {
299 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
300 op.dump(););
301 op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
302 return {};
303 }
304 return it->getSecond();
305 }
306
307 /// AffineMap rewriting fir.array_coor operation to affine apply,
308 /// %dim = fir.gendim %lowerBound, %upperBound, %stride
309 /// %a = fir.array_coor %arr(%dim) %i
310 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)>
createArrayIndexAffineMap(unsigned dimensions,MLIRContext * context)311 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
312 MLIRContext *context) {
313 auto index = mlir::getAffineConstantExpr(0, context);
314 auto accuExtent = mlir::getAffineConstantExpr(1, context);
315 for (unsigned i = 0; i < dimensions; ++i) {
316 mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context),
317 lowerBound = mlir::getAffineSymbolExpr(i * 3, context),
318 currentExtent =
319 mlir::getAffineSymbolExpr(i * 3 + 1, context),
320 stride = mlir::getAffineSymbolExpr(i * 3 + 2, context),
321 currentPart = (idx * stride - lowerBound) * accuExtent;
322 index = currentPart + index;
323 accuExtent = accuExtent * currentExtent;
324 }
325 return mlir::AffineMap::get(dimensions, dimensions * 3, index);
326 }
327
constantIntegerLike(const mlir::Value value)328 static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
329 if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>())
330 if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>())
331 return stepAttr.getInt();
332 return {};
333 }
334
coordinateArrayElement(fir::ArrayCoorOp op)335 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
336 if (auto refType =
337 op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) {
338 if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) {
339 return seqType.getEleTy();
340 }
341 }
342 op.emitError(
343 "AffineLoopConversion: array type in coordinate operation not valid\n");
344 return mlir::Type();
345 }
346
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::ShapeOp shape,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)347 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
348 SmallVectorImpl<mlir::Value> &indexArgs,
349 mlir::PatternRewriter &rewriter) {
350 auto one = rewriter.create<mlir::arith::ConstantOp>(
351 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
352 auto extents = shape.getExtents();
353 for (auto i = extents.begin(); i < extents.end(); i++) {
354 indexArgs.push_back(one);
355 indexArgs.push_back(*i);
356 indexArgs.push_back(one);
357 }
358 }
359
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::ShapeShiftOp shape,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)360 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
361 SmallVectorImpl<mlir::Value> &indexArgs,
362 mlir::PatternRewriter &rewriter) {
363 auto one = rewriter.create<mlir::arith::ConstantOp>(
364 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
365 auto extents = shape.getPairs();
366 for (auto i = extents.begin(); i < extents.end();) {
367 indexArgs.push_back(*i++);
368 indexArgs.push_back(*i++);
369 indexArgs.push_back(one);
370 }
371 }
372
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::SliceOp slice,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)373 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
374 SmallVectorImpl<mlir::Value> &indexArgs,
375 mlir::PatternRewriter &rewriter) {
376 auto extents = slice.getTriples();
377 for (auto i = extents.begin(); i < extents.end();) {
378 indexArgs.push_back(*i++);
379 indexArgs.push_back(*i++);
380 indexArgs.push_back(*i++);
381 }
382 }
383
populateIndexArgs(fir::ArrayCoorOp acoOp,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)384 static void populateIndexArgs(fir::ArrayCoorOp acoOp,
385 SmallVectorImpl<mlir::Value> &indexArgs,
386 mlir::PatternRewriter &rewriter) {
387 if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>())
388 return populateIndexArgs(acoOp, shape, indexArgs, rewriter);
389 if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>())
390 return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter);
391 if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>())
392 return populateIndexArgs(acoOp, slice, indexArgs, rewriter);
393 }
394
395 /// Returns affine.apply and fir.convert from array_coor and gendims
396 static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
createAffineOps(mlir::Value arrayRef,mlir::PatternRewriter & rewriter)397 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
398 auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
399 auto affineMap =
400 createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext());
401 SmallVector<mlir::Value> indexArgs;
402 indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end());
403
404 populateIndexArgs(acoOp, indexArgs, rewriter);
405
406 auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
407 affineMap, indexArgs);
408 auto arrayElementType = coordinateArrayElement(acoOp);
409 auto newType = mlir::MemRefType::get({-1}, arrayElementType);
410 auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType,
411 acoOp.getMemref());
412 return std::make_pair(affineApply, arrayConvert);
413 }
414
rewriteLoad(fir::LoadOp loadOp,mlir::PatternRewriter & rewriter)415 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
416 rewriter.setInsertionPoint(loadOp);
417 auto affineOps = createAffineOps(loadOp.getMemref(), rewriter);
418 rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(
419 loadOp, affineOps.second.getResult(), affineOps.first.getResult());
420 }
421
rewriteStore(fir::StoreOp storeOp,mlir::PatternRewriter & rewriter)422 static void rewriteStore(fir::StoreOp storeOp,
423 mlir::PatternRewriter &rewriter) {
424 rewriter.setInsertionPoint(storeOp);
425 auto affineOps = createAffineOps(storeOp.getMemref(), rewriter);
426 rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.getValue(),
427 affineOps.second.getResult(),
428 affineOps.first.getResult());
429 }
430
rewriteMemoryOps(Block * block,mlir::PatternRewriter & rewriter)431 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
432 for (auto &bodyOp : block->getOperations()) {
433 if (isa<fir::LoadOp>(bodyOp))
434 rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
435 if (isa<fir::StoreOp>(bodyOp))
436 rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
437 }
438 }
439
440 namespace {
441 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to
442 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir
443 /// loads and stores to affine.
444 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
445 public:
446 using OpRewritePattern::OpRewritePattern;
AffineLoopConversion(mlir::MLIRContext * context,AffineFunctionAnalysis & afa)447 AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
448 : OpRewritePattern(context), functionAnalysis(afa) {}
449
450 mlir::LogicalResult
matchAndRewrite(fir::DoLoopOp loop,mlir::PatternRewriter & rewriter) const451 matchAndRewrite(fir::DoLoopOp loop,
452 mlir::PatternRewriter &rewriter) const override {
453 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
454 loop.dump(););
455 LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
456 functionAnalysis.getChildLoopAnalysis(loop);
457 auto &loopOps = loop.getBody()->getOperations();
458 auto loopAndIndex = createAffineFor(loop, rewriter);
459 auto affineFor = loopAndIndex.first;
460 auto inductionVar = loopAndIndex.second;
461
462 rewriter.startRootUpdate(affineFor.getOperation());
463 affineFor.getBody()->getOperations().splice(
464 std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
465 std::prev(loopOps.end()));
466 rewriter.finalizeRootUpdate(affineFor.getOperation());
467
468 rewriter.startRootUpdate(loop.getOperation());
469 loop.getInductionVar().replaceAllUsesWith(inductionVar);
470 rewriter.finalizeRootUpdate(loop.getOperation());
471
472 rewriteMemoryOps(affineFor.getBody(), rewriter);
473
474 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
475 affineFor.dump(););
476 rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
477 return success();
478 }
479
480 private:
481 std::pair<mlir::AffineForOp, mlir::Value>
createAffineFor(fir::DoLoopOp op,mlir::PatternRewriter & rewriter) const482 createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
483 if (auto constantStep = constantIntegerLike(op.getStep()))
484 if (*constantStep > 0)
485 return positiveConstantStep(op, *constantStep, rewriter);
486 return genericBounds(op, rewriter);
487 }
488
489 // when step for the loop is positive compile time constant
490 std::pair<mlir::AffineForOp, mlir::Value>
positiveConstantStep(fir::DoLoopOp op,int64_t step,mlir::PatternRewriter & rewriter) const491 positiveConstantStep(fir::DoLoopOp op, int64_t step,
492 mlir::PatternRewriter &rewriter) const {
493 auto affineFor = rewriter.create<mlir::AffineForOp>(
494 op.getLoc(), ValueRange(op.getLowerBound()),
495 mlir::AffineMap::get(0, 1,
496 mlir::getAffineSymbolExpr(0, op.getContext())),
497 ValueRange(op.getUpperBound()),
498 mlir::AffineMap::get(0, 1,
499 1 + mlir::getAffineSymbolExpr(0, op.getContext())),
500 step);
501 return std::make_pair(affineFor, affineFor.getInductionVar());
502 }
503
504 std::pair<mlir::AffineForOp, mlir::Value>
genericBounds(fir::DoLoopOp op,mlir::PatternRewriter & rewriter) const505 genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
506 auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext());
507 auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext());
508 auto step = mlir::getAffineSymbolExpr(2, op.getContext());
509 mlir::AffineMap upperBoundMap = mlir::AffineMap::get(
510 0, 3, (upperBound - lowerBound + step).floorDiv(step));
511 auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>(
512 op.getLoc(), upperBoundMap,
513 ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()}));
514 auto actualIndexMap = mlir::AffineMap::get(
515 1, 2,
516 (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) *
517 mlir::getAffineSymbolExpr(1, op.getContext()));
518
519 auto affineFor = rewriter.create<mlir::AffineForOp>(
520 op.getLoc(), ValueRange(),
521 AffineMap::getConstantMap(0, op.getContext()),
522 genericUpperBound.getResult(),
523 mlir::AffineMap::get(0, 1,
524 1 + mlir::getAffineSymbolExpr(0, op.getContext())),
525 1);
526 rewriter.setInsertionPointToStart(affineFor.getBody());
527 auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
528 op.getLoc(), actualIndexMap,
529 ValueRange(
530 {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()}));
531 return std::make_pair(affineFor, actualIndex.getResult());
532 }
533
534 AffineFunctionAnalysis &functionAnalysis;
535 };
536
537 /// Convert `fir.if` to `affine.if`.
538 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
539 public:
540 using OpRewritePattern::OpRewritePattern;
AffineIfConversion(mlir::MLIRContext * context,AffineFunctionAnalysis & afa)541 AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
542 : OpRewritePattern(context) {}
543 mlir::LogicalResult
matchAndRewrite(fir::IfOp op,mlir::PatternRewriter & rewriter) const544 matchAndRewrite(fir::IfOp op,
545 mlir::PatternRewriter &rewriter) const override {
546 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
547 op.dump(););
548 auto &ifOps = op.getThenRegion().front().getOperations();
549 auto affineCondition = AffineIfCondition(op.getCondition());
550 if (!affineCondition.hasIntegerSet()) {
551 LLVM_DEBUG(
552 llvm::dbgs()
553 << "AffineIfConversion: couldn't calculate affine condition\n";);
554 return failure();
555 }
556 auto affineIf = rewriter.create<mlir::AffineIfOp>(
557 op.getLoc(), affineCondition.getIntegerSet(),
558 affineCondition.getAffineArgs(), !op.getElseRegion().empty());
559 rewriter.startRootUpdate(affineIf);
560 affineIf.getThenBlock()->getOperations().splice(
561 std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
562 std::prev(ifOps.end()));
563 if (!op.getElseRegion().empty()) {
564 auto &otherOps = op.getElseRegion().front().getOperations();
565 affineIf.getElseBlock()->getOperations().splice(
566 std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
567 std::prev(otherOps.end()));
568 }
569 rewriter.finalizeRootUpdate(affineIf);
570 rewriteMemoryOps(affineIf.getBody(), rewriter);
571
572 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
573 affineIf.dump(););
574 rewriter.replaceOp(op, affineIf.getOperation()->getResults());
575 return success();
576 }
577 };
578
579 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
580 /// where such a promotion is possible.
581 class AffineDialectPromotion
582 : public AffineDialectPromotionBase<AffineDialectPromotion> {
583 public:
runOnOperation()584 void runOnOperation() override {
585
586 auto *context = &getContext();
587 auto function = getOperation();
588 markAllAnalysesPreserved();
589 auto functionAnalysis = AffineFunctionAnalysis(function);
590 mlir::RewritePatternSet patterns(context);
591 patterns.insert<AffineIfConversion>(context, functionAnalysis);
592 patterns.insert<AffineLoopConversion>(context, functionAnalysis);
593 mlir::ConversionTarget target = *context;
594 target.addLegalDialect<
595 mlir::AffineDialect, FIROpsDialect, mlir::scf::SCFDialect,
596 mlir::arith::ArithmeticDialect, mlir::func::FuncDialect>();
597 target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
598 return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
599 });
600 target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
601 fir::DoLoopOp op) {
602 return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
603 });
604
605 LLVM_DEBUG(llvm::dbgs()
606 << "AffineDialectPromotion: running promotion on: \n";
607 function.print(llvm::dbgs()););
608 // apply the patterns
609 if (mlir::failed(mlir::applyPartialConversion(function, target,
610 std::move(patterns)))) {
611 mlir::emitError(mlir::UnknownLoc::get(context),
612 "error in converting to affine dialect\n");
613 signalPassFailure();
614 }
615 }
616 };
617 } // namespace
618
619 /// Convert FIR loop constructs to the Affine dialect
createPromoteToAffinePass()620 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
621 return std::make_unique<AffineDialectPromotion>();
622 }
623