1 //===-- AffinePromotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Dialect/FIRDialect.h" 11 #include "flang/Optimizer/Dialect/FIROps.h" 12 #include "flang/Optimizer/Dialect/FIRType.h" 13 #include "flang/Optimizer/Transforms/Passes.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/SCF/SCF.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/IR/BuiltinAttributes.h" 18 #include "mlir/IR/IntegerSet.h" 19 #include "mlir/IR/Visitors.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/ADT/Optional.h" 23 #include "llvm/Support/Debug.h" 24 25 #define DEBUG_TYPE "flang-affine-promotion" 26 27 using namespace fir; 28 29 namespace { 30 struct AffineLoopAnalysis; 31 struct AffineIfAnalysis; 32 33 /// Stores analysis objects for all loops and if operations inside a function 34 /// these analysis are used twice, first for marking operations for rewrite and 35 /// second when doing rewrite. 36 struct AffineFunctionAnalysis { 37 explicit AffineFunctionAnalysis(mlir::FuncOp funcOp) { 38 for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>()) 39 loopAnalysisMap.try_emplace(op, op, *this); 40 } 41 42 AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const; 43 44 AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const; 45 46 llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap; 47 llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap; 48 }; 49 } // namespace 50 51 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) { 52 if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) { 53 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp())) 54 return true; 55 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a " 56 "loop induction variable (owner not loopOp)\n"; 57 op->dump()); 58 return false; 59 } 60 LLVM_DEBUG( 61 llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop " 62 "induction variable (not a block argument)\n"; 63 op->dump(); coordinate.getDefiningOp()->dump()); 64 return false; 65 } 66 67 namespace { 68 struct AffineLoopAnalysis { 69 AffineLoopAnalysis() = default; 70 71 explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa) 72 : legality(analyzeLoop(op, afa)) {} 73 74 bool canPromoteToAffine() { return legality; } 75 76 private: 77 bool analyzeBody(fir::DoLoopOp loopOperation, 78 AffineFunctionAnalysis &functionAnalysis) { 79 for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) { 80 auto analysis = functionAnalysis.loopAnalysisMap 81 .try_emplace(loopOp, loopOp, functionAnalysis) 82 .first->getSecond(); 83 if (!analysis.canPromoteToAffine()) 84 return false; 85 } 86 for (auto ifOp : loopOperation.getOps<fir::IfOp>()) 87 functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis); 88 return true; 89 } 90 91 bool analyzeLoop(fir::DoLoopOp loopOperation, 92 AffineFunctionAnalysis &functionAnalysis) { 93 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump();); 94 return analyzeMemoryAccess(loopOperation) && 95 analyzeBody(loopOperation, functionAnalysis); 96 } 97 98 bool analyzeReference(mlir::Value memref, mlir::Operation *op) { 99 if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) { 100 if (acoOp.memref().getType().isa<fir::BoxType>()) { 101 // TODO: Look if and how fir.box can be promoted to affine. 102 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, " 103 "array memory operation uses fir.box\n"; 104 op->dump(); acoOp.dump();); 105 return false; 106 } 107 bool canPromote = true; 108 for (auto coordinate : acoOp.indices()) 109 canPromote = canPromote && analyzeCoordinate(coordinate, op); 110 return canPromote; 111 } 112 if (auto coOp = memref.getDefiningOp<CoordinateOp>()) { 113 LLVM_DEBUG(llvm::dbgs() 114 << "AffineLoopAnalysis: cannot promote loop, " 115 "array memory operation uses non ArrayCoorOp\n"; 116 op->dump(); coOp.dump();); 117 118 return false; 119 } 120 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory " 121 "reference for array load\n"; 122 op->dump();); 123 return false; 124 } 125 126 bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) { 127 for (auto loadOp : loopOperation.getOps<fir::LoadOp>()) 128 if (!analyzeReference(loadOp.memref(), loadOp)) 129 return false; 130 for (auto storeOp : loopOperation.getOps<fir::StoreOp>()) 131 if (!analyzeReference(storeOp.memref(), storeOp)) 132 return false; 133 return true; 134 } 135 136 bool legality{}; 137 }; 138 } // namespace 139 140 AffineLoopAnalysis 141 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const { 142 auto it = loopAnalysisMap.find_as(op); 143 if (it == loopAnalysisMap.end()) { 144 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 145 op.dump();); 146 op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n"); 147 return {}; 148 } 149 return it->getSecond(); 150 } 151 152 namespace { 153 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the 154 /// final number of symbols and dimensions of the affine map. Integer set if 155 /// possible is in Optional IntegerSet. 156 struct AffineIfCondition { 157 using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>; 158 159 explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) { 160 if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>()) 161 fromCmpIOp(condDef); 162 } 163 164 bool hasIntegerSet() const { return integerSet.hasValue(); } 165 166 mlir::IntegerSet getIntegerSet() const { 167 assert(hasIntegerSet() && "integer set is missing"); 168 return integerSet.getValue(); 169 } 170 171 mlir::ValueRange getAffineArgs() const { return affineArgs; } 172 173 private: 174 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs, 175 mlir::Value rhs) { 176 return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs)); 177 } 178 179 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs, 180 MaybeAffineExpr rhs) { 181 if (lhs.hasValue() && rhs.hasValue()) 182 return mlir::getAffineBinaryOpExpr(kind, lhs.getValue(), rhs.getValue()); 183 return {}; 184 } 185 186 MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; } 187 188 MaybeAffineExpr toAffineExpr(int64_t value) { 189 return {mlir::getAffineConstantExpr(value, firCondition.getContext())}; 190 } 191 192 /// Returns an AffineExpr if it is a result of operations that can be done 193 /// in an affine expression, this includes -, +, *, rem, constant. 194 /// block arguments of a loopOp or forOp are used as dimensions 195 MaybeAffineExpr toAffineExpr(mlir::Value value) { 196 if (auto op = value.getDefiningOp<mlir::arith::SubIOp>()) 197 return affineBinaryOp(mlir::AffineExprKind::Add, toAffineExpr(op.lhs()), 198 affineBinaryOp(mlir::AffineExprKind::Mul, 199 toAffineExpr(op.rhs()), 200 toAffineExpr(-1))); 201 if (auto op = value.getDefiningOp<mlir::arith::AddIOp>()) 202 return affineBinaryOp(mlir::AffineExprKind::Add, op.lhs(), op.rhs()); 203 if (auto op = value.getDefiningOp<mlir::arith::MulIOp>()) 204 return affineBinaryOp(mlir::AffineExprKind::Mul, op.lhs(), op.rhs()); 205 if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>()) 206 return affineBinaryOp(mlir::AffineExprKind::Mod, op.lhs(), op.rhs()); 207 if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>()) 208 if (auto intConstant = op.value().dyn_cast<IntegerAttr>()) 209 return toAffineExpr(intConstant.getInt()); 210 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) { 211 affineArgs.push_back(value); 212 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) || 213 isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp())) 214 return {mlir::getAffineDimExpr(dimCount++, value.getContext())}; 215 return {mlir::getAffineSymbolExpr(symCount++, value.getContext())}; 216 } 217 return {}; 218 } 219 220 void fromCmpIOp(mlir::arith::CmpIOp cmpOp) { 221 auto lhsAffine = toAffineExpr(cmpOp.lhs()); 222 auto rhsAffine = toAffineExpr(cmpOp.rhs()); 223 if (!lhsAffine.hasValue() || !rhsAffine.hasValue()) 224 return; 225 auto constraintPair = constraint( 226 cmpOp.predicate(), rhsAffine.getValue() - lhsAffine.getValue()); 227 if (!constraintPair) 228 return; 229 integerSet = mlir::IntegerSet::get(dimCount, symCount, 230 {constraintPair.getValue().first}, 231 {constraintPair.getValue().second}); 232 return; 233 } 234 235 llvm::Optional<std::pair<AffineExpr, bool>> 236 constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) { 237 switch (predicate) { 238 case mlir::arith::CmpIPredicate::slt: 239 return {std::make_pair(basic - 1, false)}; 240 case mlir::arith::CmpIPredicate::sle: 241 return {std::make_pair(basic, false)}; 242 case mlir::arith::CmpIPredicate::sgt: 243 return {std::make_pair(1 - basic, false)}; 244 case mlir::arith::CmpIPredicate::sge: 245 return {std::make_pair(0 - basic, false)}; 246 case mlir::arith::CmpIPredicate::eq: 247 return {std::make_pair(basic, true)}; 248 default: 249 return {}; 250 } 251 } 252 253 llvm::SmallVector<mlir::Value> affineArgs; 254 llvm::Optional<mlir::IntegerSet> integerSet; 255 mlir::Value firCondition; 256 unsigned symCount{0u}; 257 unsigned dimCount{0u}; 258 }; 259 } // namespace 260 261 namespace { 262 /// Analysis for affine promotion of fir.if 263 struct AffineIfAnalysis { 264 AffineIfAnalysis() = default; 265 266 explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa) 267 : legality(analyzeIf(op, afa)) {} 268 269 bool canPromoteToAffine() { return legality; } 270 271 private: 272 bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) { 273 if (op.getNumResults() == 0) 274 return true; 275 LLVM_DEBUG(llvm::dbgs() 276 << "AffineIfAnalysis: not promoting as op has results\n";); 277 return false; 278 } 279 280 bool legality{}; 281 }; 282 } // namespace 283 284 AffineIfAnalysis 285 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const { 286 auto it = ifAnalysisMap.find_as(op); 287 if (it == ifAnalysisMap.end()) { 288 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 289 op.dump();); 290 op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n"); 291 return {}; 292 } 293 return it->getSecond(); 294 } 295 296 /// AffineMap rewriting fir.array_coor operation to affine apply, 297 /// %dim = fir.gendim %lowerBound, %upperBound, %stride 298 /// %a = fir.array_coor %arr(%dim) %i 299 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)> 300 static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions, 301 MLIRContext *context) { 302 auto index = mlir::getAffineConstantExpr(0, context); 303 auto accuExtent = mlir::getAffineConstantExpr(1, context); 304 for (unsigned i = 0; i < dimensions; ++i) { 305 mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context), 306 lowerBound = mlir::getAffineSymbolExpr(i * 3, context), 307 currentExtent = 308 mlir::getAffineSymbolExpr(i * 3 + 1, context), 309 stride = mlir::getAffineSymbolExpr(i * 3 + 2, context), 310 currentPart = (idx * stride - lowerBound) * accuExtent; 311 index = currentPart + index; 312 accuExtent = accuExtent * currentExtent; 313 } 314 return mlir::AffineMap::get(dimensions, dimensions * 3, index); 315 } 316 317 static Optional<int64_t> constantIntegerLike(const mlir::Value value) { 318 if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>()) 319 if (auto stepAttr = definition.value().dyn_cast<IntegerAttr>()) 320 return stepAttr.getInt(); 321 return {}; 322 } 323 324 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) { 325 if (auto refType = op.memref().getType().dyn_cast_or_null<ReferenceType>()) { 326 if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) { 327 return seqType.getEleTy(); 328 } 329 } 330 op.emitError( 331 "AffineLoopConversion: array type in coordinate operation not valid\n"); 332 return mlir::Type(); 333 } 334 335 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape, 336 SmallVectorImpl<mlir::Value> &indexArgs, 337 mlir::PatternRewriter &rewriter) { 338 auto one = rewriter.create<mlir::arith::ConstantOp>( 339 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 340 auto extents = shape.extents(); 341 for (auto i = extents.begin(); i < extents.end(); i++) { 342 indexArgs.push_back(one); 343 indexArgs.push_back(*i); 344 indexArgs.push_back(one); 345 } 346 } 347 348 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape, 349 SmallVectorImpl<mlir::Value> &indexArgs, 350 mlir::PatternRewriter &rewriter) { 351 auto one = rewriter.create<mlir::arith::ConstantOp>( 352 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 353 auto extents = shape.pairs(); 354 for (auto i = extents.begin(); i < extents.end();) { 355 indexArgs.push_back(*i++); 356 indexArgs.push_back(*i++); 357 indexArgs.push_back(one); 358 } 359 } 360 361 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice, 362 SmallVectorImpl<mlir::Value> &indexArgs, 363 mlir::PatternRewriter &rewriter) { 364 auto extents = slice.triples(); 365 for (auto i = extents.begin(); i < extents.end();) { 366 indexArgs.push_back(*i++); 367 indexArgs.push_back(*i++); 368 indexArgs.push_back(*i++); 369 } 370 } 371 372 static void populateIndexArgs(fir::ArrayCoorOp acoOp, 373 SmallVectorImpl<mlir::Value> &indexArgs, 374 mlir::PatternRewriter &rewriter) { 375 if (auto shape = acoOp.shape().getDefiningOp<ShapeOp>()) 376 return populateIndexArgs(acoOp, shape, indexArgs, rewriter); 377 if (auto shapeShift = acoOp.shape().getDefiningOp<ShapeShiftOp>()) 378 return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter); 379 if (auto slice = acoOp.shape().getDefiningOp<SliceOp>()) 380 return populateIndexArgs(acoOp, slice, indexArgs, rewriter); 381 return; 382 } 383 384 /// Returns affine.apply and fir.convert from array_coor and gendims 385 static std::pair<mlir::AffineApplyOp, fir::ConvertOp> 386 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) { 387 auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>(); 388 auto affineMap = 389 createArrayIndexAffineMap(acoOp.indices().size(), acoOp.getContext()); 390 SmallVector<mlir::Value> indexArgs; 391 indexArgs.append(acoOp.indices().begin(), acoOp.indices().end()); 392 393 populateIndexArgs(acoOp, indexArgs, rewriter); 394 395 auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(), 396 affineMap, indexArgs); 397 auto arrayElementType = coordinateArrayElement(acoOp); 398 auto newType = mlir::MemRefType::get({-1}, arrayElementType); 399 auto arrayConvert = 400 rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, acoOp.memref()); 401 return std::make_pair(affineApply, arrayConvert); 402 } 403 404 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) { 405 rewriter.setInsertionPoint(loadOp); 406 auto affineOps = createAffineOps(loadOp.memref(), rewriter); 407 rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>( 408 loadOp, affineOps.second.getResult(), affineOps.first.getResult()); 409 } 410 411 static void rewriteStore(fir::StoreOp storeOp, 412 mlir::PatternRewriter &rewriter) { 413 rewriter.setInsertionPoint(storeOp); 414 auto affineOps = createAffineOps(storeOp.memref(), rewriter); 415 rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.value(), 416 affineOps.second.getResult(), 417 affineOps.first.getResult()); 418 } 419 420 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { 421 for (auto &bodyOp : block->getOperations()) { 422 if (isa<fir::LoadOp>(bodyOp)) 423 rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); 424 if (isa<fir::StoreOp>(bodyOp)) 425 rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); 426 } 427 } 428 429 namespace { 430 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to 431 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir 432 /// loads and stores to affine. 433 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { 434 public: 435 using OpRewritePattern::OpRewritePattern; 436 AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 437 : OpRewritePattern(context), functionAnalysis(afa) {} 438 439 mlir::LogicalResult 440 matchAndRewrite(fir::DoLoopOp loop, 441 mlir::PatternRewriter &rewriter) const override { 442 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n"; 443 loop.dump();); 444 LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = 445 functionAnalysis.getChildLoopAnalysis(loop); 446 auto &loopOps = loop.getBody()->getOperations(); 447 auto loopAndIndex = createAffineFor(loop, rewriter); 448 auto affineFor = loopAndIndex.first; 449 auto inductionVar = loopAndIndex.second; 450 451 rewriter.startRootUpdate(affineFor.getOperation()); 452 affineFor.getBody()->getOperations().splice( 453 std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), 454 std::prev(loopOps.end())); 455 rewriter.finalizeRootUpdate(affineFor.getOperation()); 456 457 rewriter.startRootUpdate(loop.getOperation()); 458 loop.getInductionVar().replaceAllUsesWith(inductionVar); 459 rewriter.finalizeRootUpdate(loop.getOperation()); 460 461 rewriteMemoryOps(affineFor.getBody(), rewriter); 462 463 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n"; 464 affineFor.dump();); 465 rewriter.replaceOp(loop, affineFor.getOperation()->getResults()); 466 return success(); 467 } 468 469 private: 470 std::pair<mlir::AffineForOp, mlir::Value> 471 createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 472 if (auto constantStep = constantIntegerLike(op.step())) 473 if (constantStep.getValue() > 0) 474 return positiveConstantStep(op, constantStep.getValue(), rewriter); 475 return genericBounds(op, rewriter); 476 } 477 478 // when step for the loop is positive compile time constant 479 std::pair<mlir::AffineForOp, mlir::Value> 480 positiveConstantStep(fir::DoLoopOp op, int64_t step, 481 mlir::PatternRewriter &rewriter) const { 482 auto affineFor = rewriter.create<mlir::AffineForOp>( 483 op.getLoc(), ValueRange(op.lowerBound()), 484 mlir::AffineMap::get(0, 1, 485 mlir::getAffineSymbolExpr(0, op.getContext())), 486 ValueRange(op.upperBound()), 487 mlir::AffineMap::get(0, 1, 488 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 489 step); 490 return std::make_pair(affineFor, affineFor.getInductionVar()); 491 } 492 493 std::pair<mlir::AffineForOp, mlir::Value> 494 genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 495 auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext()); 496 auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext()); 497 auto step = mlir::getAffineSymbolExpr(2, op.getContext()); 498 mlir::AffineMap upperBoundMap = mlir::AffineMap::get( 499 0, 3, (upperBound - lowerBound + step).floorDiv(step)); 500 auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>( 501 op.getLoc(), upperBoundMap, 502 ValueRange({op.lowerBound(), op.upperBound(), op.step()})); 503 auto actualIndexMap = mlir::AffineMap::get( 504 1, 2, 505 (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) * 506 mlir::getAffineSymbolExpr(1, op.getContext())); 507 508 auto affineFor = rewriter.create<mlir::AffineForOp>( 509 op.getLoc(), ValueRange(), 510 AffineMap::getConstantMap(0, op.getContext()), 511 genericUpperBound.getResult(), 512 mlir::AffineMap::get(0, 1, 513 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 514 1); 515 rewriter.setInsertionPointToStart(affineFor.getBody()); 516 auto actualIndex = rewriter.create<mlir::AffineApplyOp>( 517 op.getLoc(), actualIndexMap, 518 ValueRange({affineFor.getInductionVar(), op.lowerBound(), op.step()})); 519 return std::make_pair(affineFor, actualIndex.getResult()); 520 } 521 522 AffineFunctionAnalysis &functionAnalysis; 523 }; 524 525 /// Convert `fir.if` to `affine.if`. 526 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { 527 public: 528 using OpRewritePattern::OpRewritePattern; 529 AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 530 : OpRewritePattern(context) {} 531 mlir::LogicalResult 532 matchAndRewrite(fir::IfOp op, 533 mlir::PatternRewriter &rewriter) const override { 534 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n"; 535 op.dump();); 536 auto &ifOps = op.thenRegion().front().getOperations(); 537 auto affineCondition = AffineIfCondition(op.condition()); 538 if (!affineCondition.hasIntegerSet()) { 539 LLVM_DEBUG( 540 llvm::dbgs() 541 << "AffineIfConversion: couldn't calculate affine condition\n";); 542 return failure(); 543 } 544 auto affineIf = rewriter.create<mlir::AffineIfOp>( 545 op.getLoc(), affineCondition.getIntegerSet(), 546 affineCondition.getAffineArgs(), !op.elseRegion().empty()); 547 rewriter.startRootUpdate(affineIf); 548 affineIf.getThenBlock()->getOperations().splice( 549 std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(), 550 std::prev(ifOps.end())); 551 if (!op.elseRegion().empty()) { 552 auto &otherOps = op.elseRegion().front().getOperations(); 553 affineIf.getElseBlock()->getOperations().splice( 554 std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(), 555 std::prev(otherOps.end())); 556 } 557 rewriter.finalizeRootUpdate(affineIf); 558 rewriteMemoryOps(affineIf.getBody(), rewriter); 559 560 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n"; 561 affineIf.dump();); 562 rewriter.replaceOp(op, affineIf.getOperation()->getResults()); 563 return success(); 564 } 565 }; 566 567 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases 568 /// where such a promotion is possible. 569 class AffineDialectPromotion 570 : public AffineDialectPromotionBase<AffineDialectPromotion> { 571 public: 572 void runOnFunction() override { 573 574 auto *context = &getContext(); 575 auto function = getFunction(); 576 markAllAnalysesPreserved(); 577 auto functionAnalysis = AffineFunctionAnalysis(function); 578 mlir::OwningRewritePatternList patterns(context); 579 patterns.insert<AffineIfConversion>(context, functionAnalysis); 580 patterns.insert<AffineLoopConversion>(context, functionAnalysis); 581 mlir::ConversionTarget target = *context; 582 target.addLegalDialect< 583 mlir::AffineDialect, FIROpsDialect, mlir::scf::SCFDialect, 584 mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect>(); 585 target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { 586 return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); 587 }); 588 target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( 589 fir::DoLoopOp op) { 590 return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); 591 }); 592 593 LLVM_DEBUG(llvm::dbgs() 594 << "AffineDialectPromotion: running promotion on: \n"; 595 function.print(llvm::dbgs());); 596 // apply the patterns 597 if (mlir::failed(mlir::applyPartialConversion(function, target, 598 std::move(patterns)))) { 599 mlir::emitError(mlir::UnknownLoc::get(context), 600 "error in converting to affine dialect\n"); 601 signalPassFailure(); 602 } 603 } 604 }; 605 } // namespace 606 607 /// Convert FIR loop constructs to the Affine dialect 608 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() { 609 return std::make_unique<AffineDialectPromotion>(); 610 } 611