1 //===-- AffineDemotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that demote affine dialects operations 10 // after optimizations to FIR loops operations. 11 // It is used after the AffinePromotion pass. 12 // It is not part of the production pipeline and would need more work in order 13 // to be used in production. 14 // More information can be found in this presentation: 15 // https://slides.com/rajanwalia/deck 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "PassDetail.h" 20 #include "flang/Optimizer/Dialect/FIRDialect.h" 21 #include "flang/Optimizer/Dialect/FIROps.h" 22 #include "flang/Optimizer/Dialect/FIRType.h" 23 #include "flang/Optimizer/Transforms/Passes.h" 24 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" 25 #include "mlir/Dialect/Affine/IR/AffineOps.h" 26 #include "mlir/Dialect/MemRef/IR/MemRef.h" 27 #include "mlir/Dialect/SCF/SCF.h" 28 #include "mlir/Dialect/StandardOps/IR/Ops.h" 29 #include "mlir/IR/BuiltinAttributes.h" 30 #include "mlir/IR/IntegerSet.h" 31 #include "mlir/IR/Visitors.h" 32 #include "mlir/Pass/Pass.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "llvm/ADT/DenseMap.h" 35 #include "llvm/ADT/Optional.h" 36 #include "llvm/Support/CommandLine.h" 37 #include "llvm/Support/Debug.h" 38 39 #define DEBUG_TYPE "flang-affine-demotion" 40 41 using namespace fir; 42 43 namespace { 44 45 class AffineLoadConversion : public OpConversionPattern<mlir::AffineLoadOp> { 46 public: 47 using OpConversionPattern<mlir::AffineLoadOp>::OpConversionPattern; 48 49 LogicalResult 50 matchAndRewrite(mlir::AffineLoadOp op, OpAdaptor adaptor, 51 ConversionPatternRewriter &rewriter) const override { 52 SmallVector<Value> indices(adaptor.indices()); 53 auto maybeExpandedMap = 54 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 55 if (!maybeExpandedMap) 56 return failure(); 57 58 auto coorOp = rewriter.create<fir::CoordinateOp>( 59 op.getLoc(), fir::ReferenceType::get(op.getResult().getType()), 60 adaptor.memref(), *maybeExpandedMap); 61 62 rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult()); 63 return success(); 64 } 65 }; 66 67 class AffineStoreConversion : public OpConversionPattern<mlir::AffineStoreOp> { 68 public: 69 using OpConversionPattern<mlir::AffineStoreOp>::OpConversionPattern; 70 71 LogicalResult 72 matchAndRewrite(mlir::AffineStoreOp op, OpAdaptor adaptor, 73 ConversionPatternRewriter &rewriter) const override { 74 SmallVector<Value> indices(op.indices()); 75 auto maybeExpandedMap = 76 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 77 if (!maybeExpandedMap) 78 return failure(); 79 80 auto coorOp = rewriter.create<fir::CoordinateOp>( 81 op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()), 82 adaptor.memref(), *maybeExpandedMap); 83 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.value(), 84 coorOp.getResult()); 85 return success(); 86 } 87 }; 88 89 class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> { 90 public: 91 using OpRewritePattern::OpRewritePattern; 92 mlir::LogicalResult 93 matchAndRewrite(fir::ConvertOp op, 94 mlir::PatternRewriter &rewriter) const override { 95 if (op.res().getType().isa<mlir::MemRefType>()) { 96 // due to index calculation moving to affine maps we still need to 97 // add converts for sequence types this has a side effect of losing 98 // some information about arrays with known dimensions by creating: 99 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> 100 // !fir.ref<!fir.array<?xi32>> 101 if (auto refTy = op.value().getType().dyn_cast<fir::ReferenceType>()) 102 if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) { 103 fir::SequenceType::Shape flatShape = { 104 fir::SequenceType::getUnknownExtent()}; 105 auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy()); 106 auto flatTy = fir::ReferenceType::get(flatArrTy); 107 rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, op.value()); 108 return success(); 109 } 110 rewriter.startRootUpdate(op->getParentOp()); 111 op.getResult().replaceAllUsesWith(op.value()); 112 rewriter.finalizeRootUpdate(op->getParentOp()); 113 rewriter.eraseOp(op); 114 } 115 return success(); 116 } 117 }; 118 119 mlir::Type convertMemRef(mlir::MemRefType type) { 120 return fir::SequenceType::get( 121 SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()), 122 type.getElementType()); 123 } 124 125 class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> { 126 public: 127 using OpRewritePattern::OpRewritePattern; 128 mlir::LogicalResult 129 matchAndRewrite(memref::AllocOp op, 130 mlir::PatternRewriter &rewriter) const override { 131 rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()), 132 op.memref()); 133 return success(); 134 } 135 }; 136 137 class AffineDialectDemotion 138 : public AffineDialectDemotionBase<AffineDialectDemotion> { 139 public: 140 void runOnFunction() override { 141 auto *context = &getContext(); 142 auto function = getFunction(); 143 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n"; 144 function.print(llvm::dbgs());); 145 146 mlir::OwningRewritePatternList patterns(context); 147 patterns.insert<ConvertConversion>(context); 148 patterns.insert<AffineLoadConversion>(context); 149 patterns.insert<AffineStoreConversion>(context); 150 patterns.insert<StdAllocConversion>(context); 151 mlir::ConversionTarget target(*context); 152 target.addIllegalOp<memref::AllocOp>(); 153 target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) { 154 if (op.res().getType().isa<mlir::MemRefType>()) 155 return false; 156 return true; 157 }); 158 target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, 159 mlir::arith::ArithmeticDialect, 160 mlir::StandardOpsDialect>(); 161 162 if (mlir::failed(mlir::applyPartialConversion(function, target, 163 std::move(patterns)))) { 164 mlir::emitError(mlir::UnknownLoc::get(context), 165 "error in converting affine dialect\n"); 166 signalPassFailure(); 167 } 168 } 169 }; 170 171 } // namespace 172 173 std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() { 174 return std::make_unique<AffineDialectDemotion>(); 175 } 176