1 //===-- AffinePromotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that promotes FIR loop operations 10 // to affine dialect operations. 11 // It is not part of the production pipeline and would need more work in order 12 // to be used in production. 13 // More information can be found in this presentation: 14 // https://slides.com/rajanwalia/deck 15 // 16 //===----------------------------------------------------------------------===// 17 18 #include "PassDetail.h" 19 #include "flang/Optimizer/Dialect/FIRDialect.h" 20 #include "flang/Optimizer/Dialect/FIROps.h" 21 #include "flang/Optimizer/Dialect/FIRType.h" 22 #include "flang/Optimizer/Transforms/Passes.h" 23 #include "mlir/Dialect/Affine/IR/AffineOps.h" 24 #include "mlir/Dialect/SCF/SCF.h" 25 #include "mlir/Dialect/StandardOps/IR/Ops.h" 26 #include "mlir/IR/BuiltinAttributes.h" 27 #include "mlir/IR/IntegerSet.h" 28 #include "mlir/IR/Visitors.h" 29 #include "mlir/Transforms/DialectConversion.h" 30 #include "llvm/ADT/DenseMap.h" 31 #include "llvm/ADT/Optional.h" 32 #include "llvm/Support/Debug.h" 33 34 #define DEBUG_TYPE "flang-affine-promotion" 35 36 using namespace fir; 37 38 namespace { 39 struct AffineLoopAnalysis; 40 struct AffineIfAnalysis; 41 42 /// Stores analysis objects for all loops and if operations inside a function. 43 /// These analyses are used twice: first for marking operations for rewrite and 44 /// second when doing the rewrite. 
45 struct AffineFunctionAnalysis { 46 explicit AffineFunctionAnalysis(mlir::FuncOp funcOp) { 47 for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>()) 48 loopAnalysisMap.try_emplace(op, op, *this); 49 } 50 51 AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const; 52 53 AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const; 54 55 llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap; 56 llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap; 57 }; 58 } // namespace 59 60 static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) { 61 if (auto blockArg = coordinate.dyn_cast<mlir::BlockArgument>()) { 62 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp())) 63 return true; 64 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a " 65 "loop induction variable (owner not loopOp)\n"; 66 op->dump()); 67 return false; 68 } 69 LLVM_DEBUG( 70 llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop " 71 "induction variable (not a block argument)\n"; 72 op->dump(); coordinate.getDefiningOp()->dump()); 73 return false; 74 } 75 76 namespace { 77 struct AffineLoopAnalysis { 78 AffineLoopAnalysis() = default; 79 80 explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa) 81 : legality(analyzeLoop(op, afa)) {} 82 83 bool canPromoteToAffine() { return legality; } 84 85 private: 86 bool analyzeBody(fir::DoLoopOp loopOperation, 87 AffineFunctionAnalysis &functionAnalysis) { 88 for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) { 89 auto analysis = functionAnalysis.loopAnalysisMap 90 .try_emplace(loopOp, loopOp, functionAnalysis) 91 .first->getSecond(); 92 if (!analysis.canPromoteToAffine()) 93 return false; 94 } 95 for (auto ifOp : loopOperation.getOps<fir::IfOp>()) 96 functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis); 97 return true; 98 } 99 100 bool analyzeLoop(fir::DoLoopOp loopOperation, 101 AffineFunctionAnalysis &functionAnalysis) { 102 
LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump();); 103 return analyzeMemoryAccess(loopOperation) && 104 analyzeBody(loopOperation, functionAnalysis); 105 } 106 107 bool analyzeReference(mlir::Value memref, mlir::Operation *op) { 108 if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) { 109 if (acoOp.memref().getType().isa<fir::BoxType>()) { 110 // TODO: Look if and how fir.box can be promoted to affine. 111 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, " 112 "array memory operation uses fir.box\n"; 113 op->dump(); acoOp.dump();); 114 return false; 115 } 116 bool canPromote = true; 117 for (auto coordinate : acoOp.indices()) 118 canPromote = canPromote && analyzeCoordinate(coordinate, op); 119 return canPromote; 120 } 121 if (auto coOp = memref.getDefiningOp<CoordinateOp>()) { 122 LLVM_DEBUG(llvm::dbgs() 123 << "AffineLoopAnalysis: cannot promote loop, " 124 "array memory operation uses non ArrayCoorOp\n"; 125 op->dump(); coOp.dump();); 126 127 return false; 128 } 129 LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory " 130 "reference for array load\n"; 131 op->dump();); 132 return false; 133 } 134 135 bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) { 136 for (auto loadOp : loopOperation.getOps<fir::LoadOp>()) 137 if (!analyzeReference(loadOp.memref(), loadOp)) 138 return false; 139 for (auto storeOp : loopOperation.getOps<fir::StoreOp>()) 140 if (!analyzeReference(storeOp.memref(), storeOp)) 141 return false; 142 return true; 143 } 144 145 bool legality{}; 146 }; 147 } // namespace 148 149 AffineLoopAnalysis 150 AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const { 151 auto it = loopAnalysisMap.find_as(op); 152 if (it == loopAnalysisMap.end()) { 153 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 154 op.dump();); 155 op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n"); 156 return {}; 157 } 158 return it->getSecond(); 159 } 
160 161 namespace { 162 /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the 163 /// final number of symbols and dimensions of the affine map. Integer set if 164 /// possible is in Optional IntegerSet. 165 struct AffineIfCondition { 166 using MaybeAffineExpr = llvm::Optional<mlir::AffineExpr>; 167 168 explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) { 169 if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>()) 170 fromCmpIOp(condDef); 171 } 172 173 bool hasIntegerSet() const { return integerSet.hasValue(); } 174 175 mlir::IntegerSet getIntegerSet() const { 176 assert(hasIntegerSet() && "integer set is missing"); 177 return integerSet.getValue(); 178 } 179 180 mlir::ValueRange getAffineArgs() const { return affineArgs; } 181 182 private: 183 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs, 184 mlir::Value rhs) { 185 return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs)); 186 } 187 188 MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs, 189 MaybeAffineExpr rhs) { 190 if (lhs.hasValue() && rhs.hasValue()) 191 return mlir::getAffineBinaryOpExpr(kind, lhs.getValue(), rhs.getValue()); 192 return {}; 193 } 194 195 MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; } 196 197 MaybeAffineExpr toAffineExpr(int64_t value) { 198 return {mlir::getAffineConstantExpr(value, firCondition.getContext())}; 199 } 200 201 /// Returns an AffineExpr if it is a result of operations that can be done 202 /// in an affine expression, this includes -, +, *, rem, constant. 
203 /// block arguments of a loopOp or forOp are used as dimensions 204 MaybeAffineExpr toAffineExpr(mlir::Value value) { 205 if (auto op = value.getDefiningOp<mlir::arith::SubIOp>()) 206 return affineBinaryOp(mlir::AffineExprKind::Add, toAffineExpr(op.lhs()), 207 affineBinaryOp(mlir::AffineExprKind::Mul, 208 toAffineExpr(op.rhs()), 209 toAffineExpr(-1))); 210 if (auto op = value.getDefiningOp<mlir::arith::AddIOp>()) 211 return affineBinaryOp(mlir::AffineExprKind::Add, op.lhs(), op.rhs()); 212 if (auto op = value.getDefiningOp<mlir::arith::MulIOp>()) 213 return affineBinaryOp(mlir::AffineExprKind::Mul, op.lhs(), op.rhs()); 214 if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>()) 215 return affineBinaryOp(mlir::AffineExprKind::Mod, op.lhs(), op.rhs()); 216 if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>()) 217 if (auto intConstant = op.value().dyn_cast<IntegerAttr>()) 218 return toAffineExpr(intConstant.getInt()); 219 if (auto blockArg = value.dyn_cast<mlir::BlockArgument>()) { 220 affineArgs.push_back(value); 221 if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) || 222 isa<mlir::AffineForOp>(blockArg.getOwner()->getParentOp())) 223 return {mlir::getAffineDimExpr(dimCount++, value.getContext())}; 224 return {mlir::getAffineSymbolExpr(symCount++, value.getContext())}; 225 } 226 return {}; 227 } 228 229 void fromCmpIOp(mlir::arith::CmpIOp cmpOp) { 230 auto lhsAffine = toAffineExpr(cmpOp.lhs()); 231 auto rhsAffine = toAffineExpr(cmpOp.rhs()); 232 if (!lhsAffine.hasValue() || !rhsAffine.hasValue()) 233 return; 234 auto constraintPair = constraint( 235 cmpOp.predicate(), rhsAffine.getValue() - lhsAffine.getValue()); 236 if (!constraintPair) 237 return; 238 integerSet = mlir::IntegerSet::get(dimCount, symCount, 239 {constraintPair.getValue().first}, 240 {constraintPair.getValue().second}); 241 return; 242 } 243 244 llvm::Optional<std::pair<AffineExpr, bool>> 245 constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) { 246 switch 
(predicate) { 247 case mlir::arith::CmpIPredicate::slt: 248 return {std::make_pair(basic - 1, false)}; 249 case mlir::arith::CmpIPredicate::sle: 250 return {std::make_pair(basic, false)}; 251 case mlir::arith::CmpIPredicate::sgt: 252 return {std::make_pair(1 - basic, false)}; 253 case mlir::arith::CmpIPredicate::sge: 254 return {std::make_pair(0 - basic, false)}; 255 case mlir::arith::CmpIPredicate::eq: 256 return {std::make_pair(basic, true)}; 257 default: 258 return {}; 259 } 260 } 261 262 llvm::SmallVector<mlir::Value> affineArgs; 263 llvm::Optional<mlir::IntegerSet> integerSet; 264 mlir::Value firCondition; 265 unsigned symCount{0u}; 266 unsigned dimCount{0u}; 267 }; 268 } // namespace 269 270 namespace { 271 /// Analysis for affine promotion of fir.if 272 struct AffineIfAnalysis { 273 AffineIfAnalysis() = default; 274 275 explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa) 276 : legality(analyzeIf(op, afa)) {} 277 278 bool canPromoteToAffine() { return legality; } 279 280 private: 281 bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) { 282 if (op.getNumResults() == 0) 283 return true; 284 LLVM_DEBUG(llvm::dbgs() 285 << "AffineIfAnalysis: not promoting as op has results\n";); 286 return false; 287 } 288 289 bool legality{}; 290 }; 291 } // namespace 292 293 AffineIfAnalysis 294 AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const { 295 auto it = ifAnalysisMap.find_as(op); 296 if (it == ifAnalysisMap.end()) { 297 LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n"; 298 op.dump();); 299 op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n"); 300 return {}; 301 } 302 return it->getSecond(); 303 } 304 305 /// AffineMap rewriting fir.array_coor operation to affine apply, 306 /// %dim = fir.gendim %lowerBound, %upperBound, %stride 307 /// %a = fir.array_coor %arr(%dim) %i 308 /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)> 309 static mlir::AffineMap 
createArrayIndexAffineMap(unsigned dimensions, 310 MLIRContext *context) { 311 auto index = mlir::getAffineConstantExpr(0, context); 312 auto accuExtent = mlir::getAffineConstantExpr(1, context); 313 for (unsigned i = 0; i < dimensions; ++i) { 314 mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context), 315 lowerBound = mlir::getAffineSymbolExpr(i * 3, context), 316 currentExtent = 317 mlir::getAffineSymbolExpr(i * 3 + 1, context), 318 stride = mlir::getAffineSymbolExpr(i * 3 + 2, context), 319 currentPart = (idx * stride - lowerBound) * accuExtent; 320 index = currentPart + index; 321 accuExtent = accuExtent * currentExtent; 322 } 323 return mlir::AffineMap::get(dimensions, dimensions * 3, index); 324 } 325 326 static Optional<int64_t> constantIntegerLike(const mlir::Value value) { 327 if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>()) 328 if (auto stepAttr = definition.value().dyn_cast<IntegerAttr>()) 329 return stepAttr.getInt(); 330 return {}; 331 } 332 333 static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) { 334 if (auto refType = op.memref().getType().dyn_cast_or_null<ReferenceType>()) { 335 if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) { 336 return seqType.getEleTy(); 337 } 338 } 339 op.emitError( 340 "AffineLoopConversion: array type in coordinate operation not valid\n"); 341 return mlir::Type(); 342 } 343 344 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape, 345 SmallVectorImpl<mlir::Value> &indexArgs, 346 mlir::PatternRewriter &rewriter) { 347 auto one = rewriter.create<mlir::arith::ConstantOp>( 348 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 349 auto extents = shape.extents(); 350 for (auto i = extents.begin(); i < extents.end(); i++) { 351 indexArgs.push_back(one); 352 indexArgs.push_back(*i); 353 indexArgs.push_back(one); 354 } 355 } 356 357 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape, 358 
SmallVectorImpl<mlir::Value> &indexArgs, 359 mlir::PatternRewriter &rewriter) { 360 auto one = rewriter.create<mlir::arith::ConstantOp>( 361 acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); 362 auto extents = shape.pairs(); 363 for (auto i = extents.begin(); i < extents.end();) { 364 indexArgs.push_back(*i++); 365 indexArgs.push_back(*i++); 366 indexArgs.push_back(one); 367 } 368 } 369 370 static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice, 371 SmallVectorImpl<mlir::Value> &indexArgs, 372 mlir::PatternRewriter &rewriter) { 373 auto extents = slice.triples(); 374 for (auto i = extents.begin(); i < extents.end();) { 375 indexArgs.push_back(*i++); 376 indexArgs.push_back(*i++); 377 indexArgs.push_back(*i++); 378 } 379 } 380 381 static void populateIndexArgs(fir::ArrayCoorOp acoOp, 382 SmallVectorImpl<mlir::Value> &indexArgs, 383 mlir::PatternRewriter &rewriter) { 384 if (auto shape = acoOp.shape().getDefiningOp<ShapeOp>()) 385 return populateIndexArgs(acoOp, shape, indexArgs, rewriter); 386 if (auto shapeShift = acoOp.shape().getDefiningOp<ShapeShiftOp>()) 387 return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter); 388 if (auto slice = acoOp.shape().getDefiningOp<SliceOp>()) 389 return populateIndexArgs(acoOp, slice, indexArgs, rewriter); 390 return; 391 } 392 393 /// Returns affine.apply and fir.convert from array_coor and gendims 394 static std::pair<mlir::AffineApplyOp, fir::ConvertOp> 395 createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) { 396 auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>(); 397 auto affineMap = 398 createArrayIndexAffineMap(acoOp.indices().size(), acoOp.getContext()); 399 SmallVector<mlir::Value> indexArgs; 400 indexArgs.append(acoOp.indices().begin(), acoOp.indices().end()); 401 402 populateIndexArgs(acoOp, indexArgs, rewriter); 403 404 auto affineApply = rewriter.create<mlir::AffineApplyOp>(acoOp.getLoc(), 405 affineMap, indexArgs); 406 auto arrayElementType = 
coordinateArrayElement(acoOp); 407 auto newType = mlir::MemRefType::get({-1}, arrayElementType); 408 auto arrayConvert = 409 rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, acoOp.memref()); 410 return std::make_pair(affineApply, arrayConvert); 411 } 412 413 static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) { 414 rewriter.setInsertionPoint(loadOp); 415 auto affineOps = createAffineOps(loadOp.memref(), rewriter); 416 rewriter.replaceOpWithNewOp<mlir::AffineLoadOp>( 417 loadOp, affineOps.second.getResult(), affineOps.first.getResult()); 418 } 419 420 static void rewriteStore(fir::StoreOp storeOp, 421 mlir::PatternRewriter &rewriter) { 422 rewriter.setInsertionPoint(storeOp); 423 auto affineOps = createAffineOps(storeOp.memref(), rewriter); 424 rewriter.replaceOpWithNewOp<mlir::AffineStoreOp>(storeOp, storeOp.value(), 425 affineOps.second.getResult(), 426 affineOps.first.getResult()); 427 } 428 429 static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { 430 for (auto &bodyOp : block->getOperations()) { 431 if (isa<fir::LoadOp>(bodyOp)) 432 rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); 433 if (isa<fir::StoreOp>(bodyOp)) 434 rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); 435 } 436 } 437 438 namespace { 439 /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to 440 /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir 441 /// loads and stores to affine. 
442 class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { 443 public: 444 using OpRewritePattern::OpRewritePattern; 445 AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 446 : OpRewritePattern(context), functionAnalysis(afa) {} 447 448 mlir::LogicalResult 449 matchAndRewrite(fir::DoLoopOp loop, 450 mlir::PatternRewriter &rewriter) const override { 451 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n"; 452 loop.dump();); 453 LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = 454 functionAnalysis.getChildLoopAnalysis(loop); 455 auto &loopOps = loop.getBody()->getOperations(); 456 auto loopAndIndex = createAffineFor(loop, rewriter); 457 auto affineFor = loopAndIndex.first; 458 auto inductionVar = loopAndIndex.second; 459 460 rewriter.startRootUpdate(affineFor.getOperation()); 461 affineFor.getBody()->getOperations().splice( 462 std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), 463 std::prev(loopOps.end())); 464 rewriter.finalizeRootUpdate(affineFor.getOperation()); 465 466 rewriter.startRootUpdate(loop.getOperation()); 467 loop.getInductionVar().replaceAllUsesWith(inductionVar); 468 rewriter.finalizeRootUpdate(loop.getOperation()); 469 470 rewriteMemoryOps(affineFor.getBody(), rewriter); 471 472 LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n"; 473 affineFor.dump();); 474 rewriter.replaceOp(loop, affineFor.getOperation()->getResults()); 475 return success(); 476 } 477 478 private: 479 std::pair<mlir::AffineForOp, mlir::Value> 480 createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 481 if (auto constantStep = constantIntegerLike(op.step())) 482 if (constantStep.getValue() > 0) 483 return positiveConstantStep(op, constantStep.getValue(), rewriter); 484 return genericBounds(op, rewriter); 485 } 486 487 // when step for the loop is positive compile time constant 488 std::pair<mlir::AffineForOp, mlir::Value> 489 positiveConstantStep(fir::DoLoopOp op, 
int64_t step, 490 mlir::PatternRewriter &rewriter) const { 491 auto affineFor = rewriter.create<mlir::AffineForOp>( 492 op.getLoc(), ValueRange(op.lowerBound()), 493 mlir::AffineMap::get(0, 1, 494 mlir::getAffineSymbolExpr(0, op.getContext())), 495 ValueRange(op.upperBound()), 496 mlir::AffineMap::get(0, 1, 497 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 498 step); 499 return std::make_pair(affineFor, affineFor.getInductionVar()); 500 } 501 502 std::pair<mlir::AffineForOp, mlir::Value> 503 genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { 504 auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext()); 505 auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext()); 506 auto step = mlir::getAffineSymbolExpr(2, op.getContext()); 507 mlir::AffineMap upperBoundMap = mlir::AffineMap::get( 508 0, 3, (upperBound - lowerBound + step).floorDiv(step)); 509 auto genericUpperBound = rewriter.create<mlir::AffineApplyOp>( 510 op.getLoc(), upperBoundMap, 511 ValueRange({op.lowerBound(), op.upperBound(), op.step()})); 512 auto actualIndexMap = mlir::AffineMap::get( 513 1, 2, 514 (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) * 515 mlir::getAffineSymbolExpr(1, op.getContext())); 516 517 auto affineFor = rewriter.create<mlir::AffineForOp>( 518 op.getLoc(), ValueRange(), 519 AffineMap::getConstantMap(0, op.getContext()), 520 genericUpperBound.getResult(), 521 mlir::AffineMap::get(0, 1, 522 1 + mlir::getAffineSymbolExpr(0, op.getContext())), 523 1); 524 rewriter.setInsertionPointToStart(affineFor.getBody()); 525 auto actualIndex = rewriter.create<mlir::AffineApplyOp>( 526 op.getLoc(), actualIndexMap, 527 ValueRange({affineFor.getInductionVar(), op.lowerBound(), op.step()})); 528 return std::make_pair(affineFor, actualIndex.getResult()); 529 } 530 531 AffineFunctionAnalysis &functionAnalysis; 532 }; 533 534 /// Convert `fir.if` to `affine.if`. 
535 class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { 536 public: 537 using OpRewritePattern::OpRewritePattern; 538 AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) 539 : OpRewritePattern(context) {} 540 mlir::LogicalResult 541 matchAndRewrite(fir::IfOp op, 542 mlir::PatternRewriter &rewriter) const override { 543 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n"; 544 op.dump();); 545 auto &ifOps = op.thenRegion().front().getOperations(); 546 auto affineCondition = AffineIfCondition(op.condition()); 547 if (!affineCondition.hasIntegerSet()) { 548 LLVM_DEBUG( 549 llvm::dbgs() 550 << "AffineIfConversion: couldn't calculate affine condition\n";); 551 return failure(); 552 } 553 auto affineIf = rewriter.create<mlir::AffineIfOp>( 554 op.getLoc(), affineCondition.getIntegerSet(), 555 affineCondition.getAffineArgs(), !op.elseRegion().empty()); 556 rewriter.startRootUpdate(affineIf); 557 affineIf.getThenBlock()->getOperations().splice( 558 std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(), 559 std::prev(ifOps.end())); 560 if (!op.elseRegion().empty()) { 561 auto &otherOps = op.elseRegion().front().getOperations(); 562 affineIf.getElseBlock()->getOperations().splice( 563 std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(), 564 std::prev(otherOps.end())); 565 } 566 rewriter.finalizeRootUpdate(affineIf); 567 rewriteMemoryOps(affineIf.getBody(), rewriter); 568 569 LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n"; 570 affineIf.dump();); 571 rewriter.replaceOp(op, affineIf.getOperation()->getResults()); 572 return success(); 573 } 574 }; 575 576 /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases 577 /// where such a promotion is possible. 
578 class AffineDialectPromotion 579 : public AffineDialectPromotionBase<AffineDialectPromotion> { 580 public: 581 void runOnFunction() override { 582 583 auto *context = &getContext(); 584 auto function = getFunction(); 585 markAllAnalysesPreserved(); 586 auto functionAnalysis = AffineFunctionAnalysis(function); 587 mlir::OwningRewritePatternList patterns(context); 588 patterns.insert<AffineIfConversion>(context, functionAnalysis); 589 patterns.insert<AffineLoopConversion>(context, functionAnalysis); 590 mlir::ConversionTarget target = *context; 591 target.addLegalDialect< 592 mlir::AffineDialect, FIROpsDialect, mlir::scf::SCFDialect, 593 mlir::arith::ArithmeticDialect, mlir::StandardOpsDialect>(); 594 target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { 595 return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); 596 }); 597 target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( 598 fir::DoLoopOp op) { 599 return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); 600 }); 601 602 LLVM_DEBUG(llvm::dbgs() 603 << "AffineDialectPromotion: running promotion on: \n"; 604 function.print(llvm::dbgs());); 605 // apply the patterns 606 if (mlir::failed(mlir::applyPartialConversion(function, target, 607 std::move(patterns)))) { 608 mlir::emitError(mlir::UnknownLoc::get(context), 609 "error in converting to affine dialect\n"); 610 signalPassFailure(); 611 } 612 } 613 }; 614 } // namespace 615 616 /// Convert FIR loop constructs to the Affine dialect 617 std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() { 618 return std::make_unique<AffineDialectPromotion>(); 619 } 620