1 //===-- AffineDemotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that demote affine dialects operations 10 // after optimizations to FIR loops operations. 11 // It is used after the AffinePromotion pass. 12 // It is not part of the production pipeline and would need more work in order 13 // to be used in production. 14 // More information can be found in this presentation: 15 // https://slides.com/rajanwalia/deck 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "PassDetail.h" 20 #include "flang/Optimizer/Dialect/FIRDialect.h" 21 #include "flang/Optimizer/Dialect/FIROps.h" 22 #include "flang/Optimizer/Dialect/FIRType.h" 23 #include "flang/Optimizer/Transforms/Passes.h" 24 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 25 #include "mlir/Dialect/Affine/IR/AffineOps.h" 26 #include "mlir/Dialect/MemRef/IR/MemRef.h" 27 #include "mlir/Dialect/SCF/SCF.h" 28 #include "mlir/Dialect/StandardOps/IR/Ops.h" 29 #include "mlir/IR/BuiltinAttributes.h" 30 #include "mlir/IR/IntegerSet.h" 31 #include "mlir/IR/Visitors.h" 32 #include "mlir/Pass/Pass.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "llvm/ADT/DenseMap.h" 35 #include "llvm/ADT/Optional.h" 36 #include "llvm/Support/CommandLine.h" 37 #include "llvm/Support/Debug.h" 38 39 #define DEBUG_TYPE "flang-affine-demotion" 40 41 using namespace fir; 42 43 namespace { 44 45 class AffineLoadConversion : public OpRewritePattern<mlir::AffineLoadOp> { 46 public: 47 using OpRewritePattern<mlir::AffineLoadOp>::OpRewritePattern; 48 49 LogicalResult matchAndRewrite(mlir::AffineLoadOp op, 50 PatternRewriter &rewriter) const override { 51 SmallVector<Value> indices(op.getMapOperands()); 52 auto maybeExpandedMap = 53 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 54 if (!maybeExpandedMap) 55 return failure(); 56 57 auto coorOp = rewriter.create<fir::CoordinateOp>( 58 op.getLoc(), fir::ReferenceType::get(op.getResult().getType()), 59 op.getMemRef(), *maybeExpandedMap); 60 61 rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult()); 62 return success(); 63 } 64 }; 65 66 class AffineStoreConversion : public OpRewritePattern<mlir::AffineStoreOp> { 67 public: 68 using OpRewritePattern<mlir::AffineStoreOp>::OpRewritePattern; 69 70 LogicalResult matchAndRewrite(mlir::AffineStoreOp op, 71 PatternRewriter &rewriter) const override { 72 SmallVector<Value> indices(op.getMapOperands()); 73 auto maybeExpandedMap = 74 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 75 if (!maybeExpandedMap) 76 return failure(); 77 78 auto coorOp = rewriter.create<fir::CoordinateOp>( 79 op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()), 80 op.getMemRef(), *maybeExpandedMap); 81 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValueToStore(), 82 coorOp.getResult()); 83 return success(); 84 } 85 }; 86 87 class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> { 88 public: 89 using OpRewritePattern::OpRewritePattern; 90 mlir::LogicalResult 91 matchAndRewrite(fir::ConvertOp op, 92 mlir::PatternRewriter &rewriter) const override { 93 if (op.res().getType().isa<mlir::MemRefType>()) { 94 // due to index calculation moving to affine maps we still need to 95 // add converts for sequence types this has a side effect of losing 96 // some information about arrays with known dimensions by creating: 97 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> 98 // !fir.ref<!fir.array<?xi32>> 99 if (auto refTy = op.value().getType().dyn_cast<fir::ReferenceType>()) 100 if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) { 101 fir::SequenceType::Shape flatShape = { 102 fir::SequenceType::getUnknownExtent()}; 103 auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy()); 104 auto flatTy = fir::ReferenceType::get(flatArrTy); 105 rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, op.value()); 106 return success(); 107 } 108 rewriter.startRootUpdate(op->getParentOp()); 109 op.getResult().replaceAllUsesWith(op.value()); 110 rewriter.finalizeRootUpdate(op->getParentOp()); 111 rewriter.eraseOp(op); 112 } 113 return success(); 114 } 115 }; 116 117 mlir::Type convertMemRef(mlir::MemRefType type) { 118 return fir::SequenceType::get( 119 SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()), 120 type.getElementType()); 121 } 122 123 class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> { 124 public: 125 using OpRewritePattern::OpRewritePattern; 126 mlir::LogicalResult 127 matchAndRewrite(memref::AllocOp op, 128 mlir::PatternRewriter &rewriter) const override { 129 rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()), 130 op.memref()); 131 return success(); 132 } 133 }; 134 135 class AffineDialectDemotion 136 : public AffineDialectDemotionBase<AffineDialectDemotion> { 137 public: 138 void runOnFunction() override { 139 auto *context = &getContext(); 140 auto function = getFunction(); 141 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n"; 142 function.print(llvm::dbgs());); 143 144 mlir::OwningRewritePatternList patterns(context); 145 patterns.insert<ConvertConversion>(context); 146 patterns.insert<AffineLoadConversion>(context); 147 patterns.insert<AffineStoreConversion>(context); 148 patterns.insert<StdAllocConversion>(context); 149 mlir::ConversionTarget target(*context); 150 target.addIllegalOp<memref::AllocOp>(); 151 target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) { 152 if (op.res().getType().isa<mlir::MemRefType>()) 153 return false; 154 return true; 155 }); 156 target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, 157 mlir::arith::ArithmeticDialect, 158 mlir::StandardOpsDialect>(); 159 160 if (mlir::failed(mlir::applyPartialConversion(function, target, 161 std::move(patterns)))) { 162 mlir::emitError(mlir::UnknownLoc::get(context), 163 "error in converting affine dialect\n"); 164 signalPassFailure(); 165 } 166 } 167 }; 168 169 } // namespace 170 171 std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() { 172 return std::make_unique<AffineDialectDemotion>(); 173 } 174