1d70938bbSJean Perier //===-- AffinePromotion.cpp -----------------------------------------------===//
2d70938bbSJean Perier //
3d70938bbSJean Perier // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4d70938bbSJean Perier // See https://llvm.org/LICENSE.txt for license information.
5d70938bbSJean Perier // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6d70938bbSJean Perier //
7d70938bbSJean Perier //===----------------------------------------------------------------------===//
8b2169992SValentin Clement //
// This transformation is a prototype that promotes FIR loop operations
// to affine dialect operations.
11b2169992SValentin Clement // It is not part of the production pipeline and would need more work in order
12b2169992SValentin Clement // to be used in production.
13b2169992SValentin Clement // More information can be found in this presentation:
14b2169992SValentin Clement // https://slides.com/rajanwalia/deck
15b2169992SValentin Clement //
16b2169992SValentin Clement //===----------------------------------------------------------------------===//
17d70938bbSJean Perier
18d70938bbSJean Perier #include "PassDetail.h"
19d70938bbSJean Perier #include "flang/Optimizer/Dialect/FIRDialect.h"
20d70938bbSJean Perier #include "flang/Optimizer/Dialect/FIROps.h"
21d70938bbSJean Perier #include "flang/Optimizer/Dialect/FIRType.h"
22d70938bbSJean Perier #include "flang/Optimizer/Transforms/Passes.h"
23d70938bbSJean Perier #include "mlir/Dialect/Affine/IR/AffineOps.h"
2423aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
258b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
26d70938bbSJean Perier #include "mlir/IR/BuiltinAttributes.h"
27d70938bbSJean Perier #include "mlir/IR/IntegerSet.h"
28d70938bbSJean Perier #include "mlir/IR/Visitors.h"
29d70938bbSJean Perier #include "mlir/Transforms/DialectConversion.h"
30d70938bbSJean Perier #include "llvm/ADT/DenseMap.h"
31d70938bbSJean Perier #include "llvm/ADT/Optional.h"
32d70938bbSJean Perier #include "llvm/Support/Debug.h"
33d70938bbSJean Perier
34d70938bbSJean Perier #define DEBUG_TYPE "flang-affine-promotion"
35d70938bbSJean Perier
36d70938bbSJean Perier using namespace fir;
37092601d4SAndrzej Warzynski using namespace mlir;
38d70938bbSJean Perier
39d70938bbSJean Perier namespace {
40d70938bbSJean Perier struct AffineLoopAnalysis;
41d70938bbSJean Perier struct AffineIfAnalysis;
42d70938bbSJean Perier
43d70938bbSJean Perier /// Stores analysis objects for all loops and if operations inside a function
44d70938bbSJean Perier /// these analysis are used twice, first for marking operations for rewrite and
45d70938bbSJean Perier /// second when doing rewrite.
46d70938bbSJean Perier struct AffineFunctionAnalysis {
AffineFunctionAnalysis__anonf1cdc1ed0111::AffineFunctionAnalysis4758ceae95SRiver Riddle explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) {
48d70938bbSJean Perier for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
49d70938bbSJean Perier loopAnalysisMap.try_emplace(op, op, *this);
50d70938bbSJean Perier }
51d70938bbSJean Perier
52d70938bbSJean Perier AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
53d70938bbSJean Perier
54d70938bbSJean Perier AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const;
55d70938bbSJean Perier
56d70938bbSJean Perier llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap;
57d70938bbSJean Perier llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap;
58d70938bbSJean Perier };
59d70938bbSJean Perier } // namespace
60d70938bbSJean Perier
analyzeCoordinate(mlir::Value coordinate,mlir::Operation * op)61d70938bbSJean Perier static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) {
62d70938bbSJean Perier if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) {
63d70938bbSJean Perier if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()))
64d70938bbSJean Perier return true;
65d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a "
66d70938bbSJean Perier "loop induction variable (owner not loopOp)\n";
67d70938bbSJean Perier op->dump());
68d70938bbSJean Perier return false;
69d70938bbSJean Perier }
70d70938bbSJean Perier LLVM_DEBUG(
71d70938bbSJean Perier llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop "
72d70938bbSJean Perier "induction variable (not a block argument)\n";
73d70938bbSJean Perier op->dump(); coordinate.getDefiningOp()->dump());
74d70938bbSJean Perier return false;
75d70938bbSJean Perier }
76d70938bbSJean Perier
77d70938bbSJean Perier namespace {
78d70938bbSJean Perier struct AffineLoopAnalysis {
79d70938bbSJean Perier AffineLoopAnalysis() = default;
80d70938bbSJean Perier
AffineLoopAnalysis__anonf1cdc1ed0211::AffineLoopAnalysis81d70938bbSJean Perier explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa)
82d70938bbSJean Perier : legality(analyzeLoop(op, afa)) {}
83d70938bbSJean Perier
canPromoteToAffine__anonf1cdc1ed0211::AffineLoopAnalysis84d70938bbSJean Perier bool canPromoteToAffine() { return legality; }
85d70938bbSJean Perier
86d70938bbSJean Perier private:
analyzeBody__anonf1cdc1ed0211::AffineLoopAnalysis87d70938bbSJean Perier bool analyzeBody(fir::DoLoopOp loopOperation,
88d70938bbSJean Perier AffineFunctionAnalysis &functionAnalysis) {
89d70938bbSJean Perier for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) {
90d70938bbSJean Perier auto analysis = functionAnalysis.loopAnalysisMap
91d70938bbSJean Perier .try_emplace(loopOp, loopOp, functionAnalysis)
92d70938bbSJean Perier .first->getSecond();
93d70938bbSJean Perier if (!analysis.canPromoteToAffine())
94d70938bbSJean Perier return false;
95d70938bbSJean Perier }
96d70938bbSJean Perier for (auto ifOp : loopOperation.getOps<fir::IfOp>())
97d70938bbSJean Perier functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis);
98d70938bbSJean Perier return true;
99d70938bbSJean Perier }
100d70938bbSJean Perier
analyzeLoop__anonf1cdc1ed0211::AffineLoopAnalysis101d70938bbSJean Perier bool analyzeLoop(fir::DoLoopOp loopOperation,
102d70938bbSJean Perier AffineFunctionAnalysis &functionAnalysis) {
103d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
104d70938bbSJean Perier return analyzeMemoryAccess(loopOperation) &&
105d70938bbSJean Perier analyzeBody(loopOperation, functionAnalysis);
106d70938bbSJean Perier }
107d70938bbSJean Perier
analyzeReference__anonf1cdc1ed0211::AffineLoopAnalysis108d70938bbSJean Perier bool analyzeReference(mlir::Value memref, mlir::Operation *op) {
109d70938bbSJean Perier if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) {
110149ad3d5SShraiysh Vaishay if (acoOp.getMemref().getType().isa<fir::BoxType>()) {
111d70938bbSJean Perier // TODO: Look if and how fir.box can be promoted to affine.
112d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, "
113d70938bbSJean Perier "array memory operation uses fir.box\n";
114d70938bbSJean Perier op->dump(); acoOp.dump(););
115d70938bbSJean Perier return false;
116d70938bbSJean Perier }
117d70938bbSJean Perier bool canPromote = true;
118149ad3d5SShraiysh Vaishay for (auto coordinate : acoOp.getIndices())
119d70938bbSJean Perier canPromote = canPromote && analyzeCoordinate(coordinate, op);
120d70938bbSJean Perier return canPromote;
121d70938bbSJean Perier }
122d70938bbSJean Perier if (auto coOp = memref.getDefiningOp<CoordinateOp>()) {
123d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs()
124d70938bbSJean Perier << "AffineLoopAnalysis: cannot promote loop, "
125d70938bbSJean Perier "array memory operation uses non ArrayCoorOp\n";
126d70938bbSJean Perier op->dump(); coOp.dump(););
127d70938bbSJean Perier
128d70938bbSJean Perier return false;
129d70938bbSJean Perier }
130d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory "
131d70938bbSJean Perier "reference for array load\n";
132d70938bbSJean Perier op->dump(););
133d70938bbSJean Perier return false;
134d70938bbSJean Perier }
135d70938bbSJean Perier
analyzeMemoryAccess__anonf1cdc1ed0211::AffineLoopAnalysis136d70938bbSJean Perier bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) {
137d70938bbSJean Perier for (auto loadOp : loopOperation.getOps<fir::LoadOp>())
138149ad3d5SShraiysh Vaishay if (!analyzeReference(loadOp.getMemref(), loadOp))
139d70938bbSJean Perier return false;
140d70938bbSJean Perier for (auto storeOp : loopOperation.getOps<fir::StoreOp>())
141149ad3d5SShraiysh Vaishay if (!analyzeReference(storeOp.getMemref(), storeOp))
142d70938bbSJean Perier return false;
143d70938bbSJean Perier return true;
144d70938bbSJean Perier }
145d70938bbSJean Perier
146d70938bbSJean Perier bool legality{};
147d70938bbSJean Perier };
148d70938bbSJean Perier } // namespace
149d70938bbSJean Perier
150d70938bbSJean Perier AffineLoopAnalysis
getChildLoopAnalysis(fir::DoLoopOp op) const151d70938bbSJean Perier AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const {
152d70938bbSJean Perier auto it = loopAnalysisMap.find_as(op);
153d70938bbSJean Perier if (it == loopAnalysisMap.end()) {
154d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
155d70938bbSJean Perier op.dump(););
156d70938bbSJean Perier op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n");
157d70938bbSJean Perier return {};
158d70938bbSJean Perier }
159d70938bbSJean Perier return it->getSecond();
160d70938bbSJean Perier }
161d70938bbSJean Perier
162d70938bbSJean Perier namespace {
163d70938bbSJean Perier /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the
164d70938bbSJean Perier /// final number of symbols and dimensions of the affine map. Integer set if
165d70938bbSJean Perier /// possible is in Optional IntegerSet.
166d70938bbSJean Perier struct AffineIfCondition {
167d70938bbSJean Perier using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>;
168d70938bbSJean Perier
AffineIfCondition__anonf1cdc1ed0311::AffineIfCondition169d70938bbSJean Perier explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) {
170a54f4eaeSMogball if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>())
171d70938bbSJean Perier fromCmpIOp(condDef);
172d70938bbSJean Perier }
173d70938bbSJean Perier
hasIntegerSet__anonf1cdc1ed0311::AffineIfCondition174c82fb16fSKazu Hirata bool hasIntegerSet() const { return integerSet.has_value(); }
175d70938bbSJean Perier
getIntegerSet__anonf1cdc1ed0311::AffineIfCondition176d70938bbSJean Perier mlir::IntegerSet getIntegerSet() const {
177d70938bbSJean Perier assert(hasIntegerSet() && "integer set is missing");
178*3356d72aSKazu Hirata return integerSet.value();
179d70938bbSJean Perier }
180d70938bbSJean Perier
getAffineArgs__anonf1cdc1ed0311::AffineIfCondition181d70938bbSJean Perier mlir::ValueRange getAffineArgs() const { return affineArgs; }
182d70938bbSJean Perier
183d70938bbSJean Perier private:
affineBinaryOp__anonf1cdc1ed0311::AffineIfCondition184d70938bbSJean Perier MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs,
185d70938bbSJean Perier mlir::Value rhs) {
186d70938bbSJean Perier return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs));
187d70938bbSJean Perier }
188d70938bbSJean Perier
affineBinaryOp__anonf1cdc1ed0311::AffineIfCondition189d70938bbSJean Perier MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs,
190d70938bbSJean Perier MaybeAffineExpr rhs) {
19186b8c1d9SKazu Hirata if (lhs && rhs)
192009ab172SKazu Hirata return mlir::getAffineBinaryOpExpr(kind, *lhs, *rhs);
193d70938bbSJean Perier return {};
194d70938bbSJean Perier }
195d70938bbSJean Perier
toAffineExpr__anonf1cdc1ed0311::AffineIfCondition196d70938bbSJean Perier MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; }
197d70938bbSJean Perier
toAffineExpr__anonf1cdc1ed0311::AffineIfCondition198d70938bbSJean Perier MaybeAffineExpr toAffineExpr(int64_t value) {
199d70938bbSJean Perier return {mlir::getAffineConstantExpr(value, firCondition.getContext())};
200d70938bbSJean Perier }
201d70938bbSJean Perier
202d70938bbSJean Perier /// Returns an AffineExpr if it is a result of operations that can be done
203d70938bbSJean Perier /// in an affine expression, this includes -, +, *, rem, constant.
204d70938bbSJean Perier /// block arguments of a loopOp or forOp are used as dimensions
toAffineExpr__anonf1cdc1ed0311::AffineIfCondition205d70938bbSJean Perier MaybeAffineExpr toAffineExpr(mlir::Value value) {
206a54f4eaeSMogball if (auto op = value.getDefiningOp<mlir::arith::SubIOp>())
207feeee78aSJacques Pienaar return affineBinaryOp(
208feeee78aSJacques Pienaar mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()),
209feeee78aSJacques Pienaar affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()),
210d70938bbSJean Perier toAffineExpr(-1)));
211a54f4eaeSMogball if (auto op = value.getDefiningOp<mlir::arith::AddIOp>())
212feeee78aSJacques Pienaar return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(),
213feeee78aSJacques Pienaar op.getRhs());
214a54f4eaeSMogball if (auto op = value.getDefiningOp<mlir::arith::MulIOp>())
215feeee78aSJacques Pienaar return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(),
216feeee78aSJacques Pienaar op.getRhs());
217a54f4eaeSMogball if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>())
218feeee78aSJacques Pienaar return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(),
219feeee78aSJacques Pienaar op.getRhs());
220a54f4eaeSMogball if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>())
221feeee78aSJacques Pienaar if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>())
222d70938bbSJean Perier return toAffineExpr(intConstant.getInt());
223d70938bbSJean Perier if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) {
224d70938bbSJean Perier affineArgs.push_back(value);
225d70938bbSJean Perier if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) ||
226d70938bbSJean Perier isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp()))
227d70938bbSJean Perier return {mlir::getAffineDimExpr(dimCount++, value.getContext())};
228d70938bbSJean Perier return {mlir::getAffineSymbolExpr(symCount++, value.getContext())};
229d70938bbSJean Perier }
230d70938bbSJean Perier return {};
231d70938bbSJean Perier }
232d70938bbSJean Perier
fromCmpIOp__anonf1cdc1ed0311::AffineIfCondition233a54f4eaeSMogball void fromCmpIOp(mlir::arith::CmpIOp cmpOp) {
234feeee78aSJacques Pienaar auto lhsAffine = toAffineExpr(cmpOp.getLhs());
235feeee78aSJacques Pienaar auto rhsAffine = toAffineExpr(cmpOp.getRhs());
23686b8c1d9SKazu Hirata if (!lhsAffine || !rhsAffine)
237d70938bbSJean Perier return;
238*3356d72aSKazu Hirata auto constraintPair =
239*3356d72aSKazu Hirata constraint(cmpOp.getPredicate(), *rhsAffine - *lhsAffine);
240d70938bbSJean Perier if (!constraintPair)
241d70938bbSJean Perier return;
242fac0fb4dSKazu Hirata integerSet = mlir::IntegerSet::get(
243fac0fb4dSKazu Hirata dimCount, symCount, {constraintPair->first}, {constraintPair->second});
244d70938bbSJean Perier }
245d70938bbSJean Perier
246d70938bbSJean Perier llvm::Optional<std::pair<AffineExpr, bool>>
constraint__anonf1cdc1ed0311::AffineIfCondition247a54f4eaeSMogball constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) {
248d70938bbSJean Perier switch (predicate) {
249a54f4eaeSMogball case mlir::arith::CmpIPredicate::slt:
250d70938bbSJean Perier return {std::make_pair(basic - 1, false)};
251a54f4eaeSMogball case mlir::arith::CmpIPredicate::sle:
252d70938bbSJean Perier return {std::make_pair(basic, false)};
253a54f4eaeSMogball case mlir::arith::CmpIPredicate::sgt:
254d70938bbSJean Perier return {std::make_pair(1 - basic, false)};
255a54f4eaeSMogball case mlir::arith::CmpIPredicate::sge:
256d70938bbSJean Perier return {std::make_pair(0 - basic, false)};
257a54f4eaeSMogball case mlir::arith::CmpIPredicate::eq:
258d70938bbSJean Perier return {std::make_pair(basic, true)};
259d70938bbSJean Perier default:
260d70938bbSJean Perier return {};
261d70938bbSJean Perier }
262d70938bbSJean Perier }
263d70938bbSJean Perier
264d70938bbSJean Perier llvm::SmallVector<mlir::Value> affineArgs;
265d70938bbSJean Perier llvm::Optional<mlir::IntegerSet> integerSet;
266d70938bbSJean Perier mlir::Value firCondition;
267d70938bbSJean Perier unsigned symCount{0u};
268d70938bbSJean Perier unsigned dimCount{0u};
269d70938bbSJean Perier };
270d70938bbSJean Perier } // namespace
271d70938bbSJean Perier
272d70938bbSJean Perier namespace {
273d70938bbSJean Perier /// Analysis for affine promotion of fir.if
274d70938bbSJean Perier struct AffineIfAnalysis {
275d70938bbSJean Perier AffineIfAnalysis() = default;
276d70938bbSJean Perier
AffineIfAnalysis__anonf1cdc1ed0411::AffineIfAnalysis277d70938bbSJean Perier explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa)
278d70938bbSJean Perier : legality(analyzeIf(op, afa)) {}
279d70938bbSJean Perier
canPromoteToAffine__anonf1cdc1ed0411::AffineIfAnalysis280d70938bbSJean Perier bool canPromoteToAffine() { return legality; }
281d70938bbSJean Perier
282d70938bbSJean Perier private:
analyzeIf__anonf1cdc1ed0411::AffineIfAnalysis283d70938bbSJean Perier bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) {
284d70938bbSJean Perier if (op.getNumResults() == 0)
285d70938bbSJean Perier return true;
286d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs()
287d70938bbSJean Perier << "AffineIfAnalysis: not promoting as op has results\n";);
288d70938bbSJean Perier return false;
289d70938bbSJean Perier }
290d70938bbSJean Perier
291d70938bbSJean Perier bool legality{};
292d70938bbSJean Perier };
293d70938bbSJean Perier } // namespace
294d70938bbSJean Perier
295d70938bbSJean Perier AffineIfAnalysis
getChildIfAnalysis(fir::IfOp op) const296d70938bbSJean Perier AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const {
297d70938bbSJean Perier auto it = ifAnalysisMap.find_as(op);
298d70938bbSJean Perier if (it == ifAnalysisMap.end()) {
299d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n";
300d70938bbSJean Perier op.dump(););
301d70938bbSJean Perier op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n");
302d70938bbSJean Perier return {};
303d70938bbSJean Perier }
304d70938bbSJean Perier return it->getSecond();
305d70938bbSJean Perier }
306d70938bbSJean Perier
307d70938bbSJean Perier /// AffineMap rewriting fir.array_coor operation to affine apply,
308d70938bbSJean Perier /// %dim = fir.gendim %lowerBound, %upperBound, %stride
309d70938bbSJean Perier /// %a = fir.array_coor %arr(%dim) %i
310d70938bbSJean Perier /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)>
createArrayIndexAffineMap(unsigned dimensions,MLIRContext * context)311d70938bbSJean Perier static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions,
312d70938bbSJean Perier MLIRContext *context) {
313d70938bbSJean Perier auto index = mlir::getAffineConstantExpr(0, context);
314d70938bbSJean Perier auto accuExtent = mlir::getAffineConstantExpr(1, context);
315d70938bbSJean Perier for (unsigned i = 0; i < dimensions; ++i) {
316d70938bbSJean Perier mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context),
317d70938bbSJean Perier lowerBound = mlir::getAffineSymbolExpr(i * 3, context),
318d70938bbSJean Perier currentExtent =
319d70938bbSJean Perier mlir::getAffineSymbolExpr(i * 3 + 1, context),
320d70938bbSJean Perier stride = mlir::getAffineSymbolExpr(i * 3 + 2, context),
321d70938bbSJean Perier currentPart = (idx * stride - lowerBound) * accuExtent;
322d70938bbSJean Perier index = currentPart + index;
323d70938bbSJean Perier accuExtent = accuExtent * currentExtent;
324d70938bbSJean Perier }
325d70938bbSJean Perier return mlir::AffineMap::get(dimensions, dimensions * 3, index);
326d70938bbSJean Perier }
327d70938bbSJean Perier
constantIntegerLike(const mlir::Value value)328d70938bbSJean Perier static Optional<int64_t> constantIntegerLike(const mlir::Value value) {
329a54f4eaeSMogball if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>())
3303012f35fSJacques Pienaar if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>())
331d70938bbSJean Perier return stepAttr.getInt();
332d70938bbSJean Perier return {};
333d70938bbSJean Perier }
334d70938bbSJean Perier
coordinateArrayElement(fir::ArrayCoorOp op)335d70938bbSJean Perier static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) {
336149ad3d5SShraiysh Vaishay if (auto refType =
337149ad3d5SShraiysh Vaishay op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) {
338d70938bbSJean Perier if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) {
339d70938bbSJean Perier return seqType.getEleTy();
340d70938bbSJean Perier }
341d70938bbSJean Perier }
342d70938bbSJean Perier op.emitError(
343d70938bbSJean Perier "AffineLoopConversion: array type in coordinate operation not valid\n");
344d70938bbSJean Perier return mlir::Type();
345d70938bbSJean Perier }
346d70938bbSJean Perier
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::ShapeOp shape,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)347d70938bbSJean Perier static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape,
348d70938bbSJean Perier SmallVectorImpl<mlir::Value> &indexArgs,
349d70938bbSJean Perier mlir::PatternRewriter &rewriter) {
350a54f4eaeSMogball auto one = rewriter.create<mlir::arith::ConstantOp>(
351d70938bbSJean Perier acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
352149ad3d5SShraiysh Vaishay auto extents = shape.getExtents();
353d70938bbSJean Perier for (auto i = extents.begin(); i < extents.end(); i++) {
354d70938bbSJean Perier indexArgs.push_back(one);
355d70938bbSJean Perier indexArgs.push_back(*i);
356d70938bbSJean Perier indexArgs.push_back(one);
357d70938bbSJean Perier }
358d70938bbSJean Perier }
359d70938bbSJean Perier
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::ShapeShiftOp shape,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)360d70938bbSJean Perier static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape,
361d70938bbSJean Perier SmallVectorImpl<mlir::Value> &indexArgs,
362d70938bbSJean Perier mlir::PatternRewriter &rewriter) {
363a54f4eaeSMogball auto one = rewriter.create<mlir::arith::ConstantOp>(
364d70938bbSJean Perier acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1));
365149ad3d5SShraiysh Vaishay auto extents = shape.getPairs();
366d70938bbSJean Perier for (auto i = extents.begin(); i < extents.end();) {
367d70938bbSJean Perier indexArgs.push_back(*i++);
368d70938bbSJean Perier indexArgs.push_back(*i++);
369d70938bbSJean Perier indexArgs.push_back(one);
370d70938bbSJean Perier }
371d70938bbSJean Perier }
372d70938bbSJean Perier
populateIndexArgs(fir::ArrayCoorOp acoOp,fir::SliceOp slice,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)373d70938bbSJean Perier static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice,
374d70938bbSJean Perier SmallVectorImpl<mlir::Value> &indexArgs,
375d70938bbSJean Perier mlir::PatternRewriter &rewriter) {
376149ad3d5SShraiysh Vaishay auto extents = slice.getTriples();
377d70938bbSJean Perier for (auto i = extents.begin(); i < extents.end();) {
378d70938bbSJean Perier indexArgs.push_back(*i++);
379d70938bbSJean Perier indexArgs.push_back(*i++);
380d70938bbSJean Perier indexArgs.push_back(*i++);
381d70938bbSJean Perier }
382d70938bbSJean Perier }
383d70938bbSJean Perier
populateIndexArgs(fir::ArrayCoorOp acoOp,SmallVectorImpl<mlir::Value> & indexArgs,mlir::PatternRewriter & rewriter)384d70938bbSJean Perier static void populateIndexArgs(fir::ArrayCoorOp acoOp,
385d70938bbSJean Perier SmallVectorImpl<mlir::Value> &indexArgs,
386d70938bbSJean Perier mlir::PatternRewriter &rewriter) {
387149ad3d5SShraiysh Vaishay if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>())
388d70938bbSJean Perier return populateIndexArgs(acoOp, shape, indexArgs, rewriter);
389149ad3d5SShraiysh Vaishay if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>())
390d70938bbSJean Perier return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter);
391149ad3d5SShraiysh Vaishay if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>())
392d70938bbSJean Perier return populateIndexArgs(acoOp, slice, indexArgs, rewriter);
393d70938bbSJean Perier }
394d70938bbSJean Perier
395d70938bbSJean Perier /// Returns affine.apply and fir.convert from array_coor and gendims
396d70938bbSJean Perier static std::pair<mlir::AffineApplyOp, fir::ConvertOp>
createAffineOps(mlir::Value arrayRef,mlir::PatternRewriter & rewriter)397d70938bbSJean Perier createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) {
398d70938bbSJean Perier auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>();
399d70938bbSJean Perier auto affineMap =
400149ad3d5SShraiysh Vaishay createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext());
401d70938bbSJean Perier SmallVector<mlir::Value> indexArgs;
402149ad3d5SShraiysh Vaishay indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end());
403d70938bbSJean Perier
404d70938bbSJean Perier populateIndexArgs(acoOp, indexArgs, rewriter);
405d70938bbSJean Perier
406d70938bbSJean Perier auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(),
407d70938bbSJean Perier affineMap, indexArgs);
408d70938bbSJean Perier auto arrayElementType = coordinateArrayElement(acoOp);
409d70938bbSJean Perier auto newType = mlir::MemRefType::get({-1}, arrayElementType);
410149ad3d5SShraiysh Vaishay auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType,
411149ad3d5SShraiysh Vaishay acoOp.getMemref());
412d70938bbSJean Perier return std::make_pair(affineApply, arrayConvert);
413d70938bbSJean Perier }
414d70938bbSJean Perier
rewriteLoad(fir::LoadOp loadOp,mlir::PatternRewriter & rewriter)415d70938bbSJean Perier static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) {
416d70938bbSJean Perier rewriter.setInsertionPoint(loadOp);
417149ad3d5SShraiysh Vaishay auto affineOps = createAffineOps(loadOp.getMemref(), rewriter);
418d70938bbSJean Perier rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>(
419d70938bbSJean Perier loadOp, affineOps.second.getResult(), affineOps.first.getResult());
420d70938bbSJean Perier }
421d70938bbSJean Perier
rewriteStore(fir::StoreOp storeOp,mlir::PatternRewriter & rewriter)422d70938bbSJean Perier static void rewriteStore(fir::StoreOp storeOp,
423d70938bbSJean Perier mlir::PatternRewriter &rewriter) {
424d70938bbSJean Perier rewriter.setInsertionPoint(storeOp);
425149ad3d5SShraiysh Vaishay auto affineOps = createAffineOps(storeOp.getMemref(), rewriter);
426149ad3d5SShraiysh Vaishay rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.getValue(),
427d70938bbSJean Perier affineOps.second.getResult(),
428d70938bbSJean Perier affineOps.first.getResult());
429d70938bbSJean Perier }
430d70938bbSJean Perier
rewriteMemoryOps(Block * block,mlir::PatternRewriter & rewriter)431d70938bbSJean Perier static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) {
432d70938bbSJean Perier for (auto &bodyOp : block->getOperations()) {
433d70938bbSJean Perier if (isa<fir::LoadOp>(bodyOp))
434d70938bbSJean Perier rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter);
435d70938bbSJean Perier if (isa<fir::StoreOp>(bodyOp))
436d70938bbSJean Perier rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter);
437d70938bbSJean Perier }
438d70938bbSJean Perier }
439d70938bbSJean Perier
440d70938bbSJean Perier namespace {
/// Converts `fir.do_loop` to `affine.for`, creates `fir.convert` operations
/// that turn arrays into memrefs, rewrites `fir.array_coor` to `affine.apply`
/// with an affine map, and rewrites FIR loads and stores to their affine
/// counterparts.
444d70938bbSJean Perier class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
445d70938bbSJean Perier public:
446d70938bbSJean Perier using OpRewritePattern::OpRewritePattern;
AffineLoopConversion(mlir::MLIRContext * context,AffineFunctionAnalysis & afa)447d70938bbSJean Perier AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
448d70938bbSJean Perier : OpRewritePattern(context), functionAnalysis(afa) {}
449d70938bbSJean Perier
/// Rewrites `loop` into an `affine.for`.
///
/// Steps:
/// 1. Create the new loop (createAffineFor).
/// 2. Move every operation of the `fir.do_loop` body except its terminator
///    in front of the `affine.for` terminator.
/// 3. Redirect uses of the old induction variable to the replacement value.
/// 4. Rewrite the moved loads and stores to affine operations.
/// 5. Replace `loop` with the new loop's results.
///
/// NOTE(review): canPromoteToAffine() is not checked here; legality is
/// presumably enforced by the conversion target set up by the pass — confirm.
/// The `affine.for` is created without iter_args, so replacing `loop` with its
/// results assumes the `fir.do_loop` has no results — confirm.
mlir::LogicalResult
matchAndRewrite(fir::DoLoopOp loop,
                mlir::PatternRewriter &rewriter) const override {
  LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n";
             loop.dump(););
  // Only called for its side effect: getChildLoopAnalysis emits an error if
  // `loop` was never analyzed. The returned analysis is not used.
  LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
      functionAnalysis.getChildLoopAnalysis(loop);
  auto &loopOps = loop.getBody()->getOperations();
  auto loopAndIndex = createAffineFor(loop, rewriter);
  auto affineFor = loopAndIndex.first;
  // Value standing in for the fir.do_loop induction variable inside the new
  // loop (the affine IV itself, or a recomputed index in the generic case).
  auto inductionVar = loopAndIndex.second;

  // Move [begin, terminator) of the old body before the affine.for
  // terminator, i.e. after anything createAffineFor already put in the body.
  rewriter.startRootUpdate(affineFor.getOperation());
  affineFor.getBody()->getOperations().splice(
      std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
      std::prev(loopOps.end()));
  rewriter.finalizeRootUpdate(affineFor.getOperation());

  rewriter.startRootUpdate(loop.getOperation());
  loop.getInductionVar().replaceAllUsesWith(inductionVar);
  rewriter.finalizeRootUpdate(loop.getOperation());

  // Only operations directly in the new body are visited.
  rewriteMemoryOps(affineFor.getBody(), rewriter);

  LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
             affineFor.dump(););
  rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
  return success();
}
479d70938bbSJean Perier
480d70938bbSJean Perier private:
481d70938bbSJean Perier std::pair<mlir::AffineForOp, mlir::Value>
createAffineFor(fir::DoLoopOp op,mlir::PatternRewriter & rewriter) const482d70938bbSJean Perier createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
483149ad3d5SShraiysh Vaishay if (auto constantStep = constantIntegerLike(op.getStep()))
484*3356d72aSKazu Hirata if (*constantStep > 0)
485*3356d72aSKazu Hirata return positiveConstantStep(op, *constantStep, rewriter);
486d70938bbSJean Perier return genericBounds(op, rewriter);
487d70938bbSJean Perier }
488d70938bbSJean Perier
489d70938bbSJean Perier // when step for the loop is positive compile time constant
490d70938bbSJean Perier std::pair<mlir::AffineForOp, mlir::Value>
positiveConstantStep(fir::DoLoopOp op,int64_t step,mlir::PatternRewriter & rewriter) const491d70938bbSJean Perier positiveConstantStep(fir::DoLoopOp op, int64_t step,
492d70938bbSJean Perier mlir::PatternRewriter &rewriter) const {
493d70938bbSJean Perier auto affineFor = rewriter.create<mlir::AffineForOp>(
494149ad3d5SShraiysh Vaishay op.getLoc(), ValueRange(op.getLowerBound()),
495d70938bbSJean Perier mlir::AffineMap::get(0, 1,
496d70938bbSJean Perier mlir::getAffineSymbolExpr(0, op.getContext())),
497149ad3d5SShraiysh Vaishay ValueRange(op.getUpperBound()),
498d70938bbSJean Perier mlir::AffineMap::get(0, 1,
499d70938bbSJean Perier 1 + mlir::getAffineSymbolExpr(0, op.getContext())),
500d70938bbSJean Perier step);
501d70938bbSJean Perier return std::make_pair(affineFor, affineFor.getInductionVar());
502d70938bbSJean Perier }
503d70938bbSJean Perier
504d70938bbSJean Perier std::pair<mlir::AffineForOp, mlir::Value>
genericBounds(fir::DoLoopOp op,mlir::PatternRewriter & rewriter) const505d70938bbSJean Perier genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const {
506d70938bbSJean Perier auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext());
507d70938bbSJean Perier auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext());
508d70938bbSJean Perier auto step = mlir::getAffineSymbolExpr(2, op.getContext());
509d70938bbSJean Perier mlir::AffineMap upperBoundMap = mlir::AffineMap::get(
510d70938bbSJean Perier 0, 3, (upperBound - lowerBound + step).floorDiv(step));
511d70938bbSJean Perier auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>(
512d70938bbSJean Perier op.getLoc(), upperBoundMap,
513149ad3d5SShraiysh Vaishay ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()}));
514d70938bbSJean Perier auto actualIndexMap = mlir::AffineMap::get(
515d70938bbSJean Perier 1, 2,
516d70938bbSJean Perier (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) *
517d70938bbSJean Perier mlir::getAffineSymbolExpr(1, op.getContext()));
518d70938bbSJean Perier
519d70938bbSJean Perier auto affineFor = rewriter.create<mlir::AffineForOp>(
520d70938bbSJean Perier op.getLoc(), ValueRange(),
521d70938bbSJean Perier AffineMap::getConstantMap(0, op.getContext()),
522d70938bbSJean Perier genericUpperBound.getResult(),
523d70938bbSJean Perier mlir::AffineMap::get(0, 1,
524d70938bbSJean Perier 1 + mlir::getAffineSymbolExpr(0, op.getContext())),
525d70938bbSJean Perier 1);
526d70938bbSJean Perier rewriter.setInsertionPointToStart(affineFor.getBody());
527d70938bbSJean Perier auto actualIndex = rewriter.create<mlir::AffineApplyOp>(
528d70938bbSJean Perier op.getLoc(), actualIndexMap,
529149ad3d5SShraiysh Vaishay ValueRange(
530149ad3d5SShraiysh Vaishay {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()}));
531d70938bbSJean Perier return std::make_pair(affineFor, actualIndex.getResult());
532d70938bbSJean Perier }
533d70938bbSJean Perier
534d70938bbSJean Perier AffineFunctionAnalysis &functionAnalysis;
535d70938bbSJean Perier };
536d70938bbSJean Perier
537d70938bbSJean Perier /// Convert `fir.if` to `affine.if`.
538d70938bbSJean Perier class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
539d70938bbSJean Perier public:
540d70938bbSJean Perier using OpRewritePattern::OpRewritePattern;
AffineIfConversion(mlir::MLIRContext * context,AffineFunctionAnalysis & afa)541d70938bbSJean Perier AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
542d70938bbSJean Perier : OpRewritePattern(context) {}
543d70938bbSJean Perier mlir::LogicalResult
matchAndRewrite(fir::IfOp op,mlir::PatternRewriter & rewriter) const544d70938bbSJean Perier matchAndRewrite(fir::IfOp op,
545d70938bbSJean Perier mlir::PatternRewriter &rewriter) const override {
546d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n";
547d70938bbSJean Perier op.dump(););
548149ad3d5SShraiysh Vaishay auto &ifOps = op.getThenRegion().front().getOperations();
549149ad3d5SShraiysh Vaishay auto affineCondition = AffineIfCondition(op.getCondition());
550d70938bbSJean Perier if (!affineCondition.hasIntegerSet()) {
551d70938bbSJean Perier LLVM_DEBUG(
552d70938bbSJean Perier llvm::dbgs()
553d70938bbSJean Perier << "AffineIfConversion: couldn't calculate affine condition\n";);
554d70938bbSJean Perier return failure();
555d70938bbSJean Perier }
556d70938bbSJean Perier auto affineIf = rewriter.create<mlir::AffineIfOp>(
557d70938bbSJean Perier op.getLoc(), affineCondition.getIntegerSet(),
558149ad3d5SShraiysh Vaishay affineCondition.getAffineArgs(), !op.getElseRegion().empty());
559d70938bbSJean Perier rewriter.startRootUpdate(affineIf);
560d70938bbSJean Perier affineIf.getThenBlock()->getOperations().splice(
561d70938bbSJean Perier std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(),
562d70938bbSJean Perier std::prev(ifOps.end()));
563149ad3d5SShraiysh Vaishay if (!op.getElseRegion().empty()) {
564149ad3d5SShraiysh Vaishay auto &otherOps = op.getElseRegion().front().getOperations();
565d70938bbSJean Perier affineIf.getElseBlock()->getOperations().splice(
566d70938bbSJean Perier std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(),
567d70938bbSJean Perier std::prev(otherOps.end()));
568d70938bbSJean Perier }
569d70938bbSJean Perier rewriter.finalizeRootUpdate(affineIf);
570d70938bbSJean Perier rewriteMemoryOps(affineIf.getBody(), rewriter);
571d70938bbSJean Perier
572d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n";
573d70938bbSJean Perier affineIf.dump(););
574d70938bbSJean Perier rewriter.replaceOp(op, affineIf.getOperation()->getResults());
575d70938bbSJean Perier return success();
576d70938bbSJean Perier }
577d70938bbSJean Perier };
578d70938bbSJean Perier
579d70938bbSJean Perier /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
580d70938bbSJean Perier /// where such a promotion is possible.
581d70938bbSJean Perier class AffineDialectPromotion
582d70938bbSJean Perier : public AffineDialectPromotionBase<AffineDialectPromotion> {
583d70938bbSJean Perier public:
runOnOperation()584196c4279SRiver Riddle void runOnOperation() override {
585d70938bbSJean Perier
586d70938bbSJean Perier auto *context = &getContext();
587196c4279SRiver Riddle auto function = getOperation();
588d70938bbSJean Perier markAllAnalysesPreserved();
589d70938bbSJean Perier auto functionAnalysis = AffineFunctionAnalysis(function);
5909f85c198SRiver Riddle mlir::RewritePatternSet patterns(context);
591d70938bbSJean Perier patterns.insert<AffineIfConversion>(context, functionAnalysis);
592d70938bbSJean Perier patterns.insert<AffineLoopConversion>(context, functionAnalysis);
593d70938bbSJean Perier mlir::ConversionTarget target = *context;
594a54f4eaeSMogball target.addLegalDialect<
595a54f4eaeSMogball mlir::AffineDialect, FIROpsDialect, mlir::scf::SCFDialect,
59623aa5a74SRiver Riddle mlir::arith::ArithmeticDialect, mlir::func::FuncDialect>();
597d70938bbSJean Perier target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) {
598d70938bbSJean Perier return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine());
599d70938bbSJean Perier });
600d70938bbSJean Perier target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis](
601d70938bbSJean Perier fir::DoLoopOp op) {
602d70938bbSJean Perier return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine());
603d70938bbSJean Perier });
604d70938bbSJean Perier
605d70938bbSJean Perier LLVM_DEBUG(llvm::dbgs()
606d70938bbSJean Perier << "AffineDialectPromotion: running promotion on: \n";
607d70938bbSJean Perier function.print(llvm::dbgs()););
608d70938bbSJean Perier // apply the patterns
609d70938bbSJean Perier if (mlir::failed(mlir::applyPartialConversion(function, target,
610d70938bbSJean Perier std::move(patterns)))) {
611d70938bbSJean Perier mlir::emitError(mlir::UnknownLoc::get(context),
612d70938bbSJean Perier "error in converting to affine dialect\n");
613d70938bbSJean Perier signalPassFailure();
614d70938bbSJean Perier }
615d70938bbSJean Perier }
616d70938bbSJean Perier };
617d70938bbSJean Perier } // namespace
618d70938bbSJean Perier
619d70938bbSJean Perier /// Convert FIR loop constructs to the Affine dialect
createPromoteToAffinePass()620d70938bbSJean Perier std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() {
621d70938bbSJean Perier return std::make_unique<AffineDialectPromotion>();
622d70938bbSJean Perier }
623