1 //===-- AffineDemotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that demote affine dialects operations 10 // after optimizations to FIR loops operations. 11 // It is used after the AffinePromotion pass. 12 // It is not part of the production pipeline and would need more work in order 13 // to be used in production. 14 // More information can be found in this presentation: 15 // https://slides.com/rajanwalia/deck 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "PassDetail.h" 20 #include "flang/Optimizer/Dialect/FIRDialect.h" 21 #include "flang/Optimizer/Dialect/FIROps.h" 22 #include "flang/Optimizer/Dialect/FIRType.h" 23 #include "flang/Optimizer/Transforms/Passes.h" 24 #include "mlir/Dialect/Affine/IR/AffineOps.h" 25 #include "mlir/Dialect/Affine/Utils.h" 26 #include "mlir/Dialect/MemRef/IR/MemRef.h" 27 #include "mlir/Dialect/SCF/SCF.h" 28 #include "mlir/Dialect/StandardOps/IR/Ops.h" 29 #include "mlir/IR/BuiltinAttributes.h" 30 #include "mlir/IR/IntegerSet.h" 31 #include "mlir/IR/Visitors.h" 32 #include "mlir/Pass/Pass.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "llvm/ADT/DenseMap.h" 35 #include "llvm/ADT/Optional.h" 36 #include "llvm/Support/CommandLine.h" 37 #include "llvm/Support/Debug.h" 38 39 #define DEBUG_TYPE "flang-affine-demotion" 40 41 using namespace fir; 42 43 namespace { 44 45 class AffineLoadConversion : public OpConversionPattern<mlir::AffineLoadOp> { 46 public: 47 using OpConversionPattern<mlir::AffineLoadOp>::OpConversionPattern; 48 49 LogicalResult 50 matchAndRewrite(mlir::AffineLoadOp op, OpAdaptor adaptor, 51 ConversionPatternRewriter &rewriter) const override { 52 SmallVector<Value> indices(adaptor.indices()); 53 auto maybeExpandedMap = 54 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 55 if (!maybeExpandedMap) 56 return failure(); 57 58 auto coorOp = rewriter.create<fir::CoordinateOp>( 59 op.getLoc(), fir::ReferenceType::get(op.getResult().getType()), 60 adaptor.memref(), *maybeExpandedMap); 61 62 rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult()); 63 return success(); 64 } 65 }; 66 67 class AffineStoreConversion : public OpConversionPattern<mlir::AffineStoreOp> { 68 public: 69 using OpConversionPattern<mlir::AffineStoreOp>::OpConversionPattern; 70 71 LogicalResult 72 matchAndRewrite(mlir::AffineStoreOp op, OpAdaptor adaptor, 73 ConversionPatternRewriter &rewriter) const override { 74 SmallVector<Value> indices(op.indices()); 75 auto maybeExpandedMap = 76 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 77 if (!maybeExpandedMap) 78 return failure(); 79 80 auto coorOp = rewriter.create<fir::CoordinateOp>( 81 op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()), 82 adaptor.memref(), *maybeExpandedMap); 83 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.value(), 84 coorOp.getResult()); 85 return success(); 86 } 87 }; 88 89 class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> { 90 public: 91 using OpRewritePattern::OpRewritePattern; 92 mlir::LogicalResult 93 matchAndRewrite(fir::ConvertOp op, 94 mlir::PatternRewriter &rewriter) const override { 95 if (op.getRes().getType().isa<mlir::MemRefType>()) { 96 // due to index calculation moving to affine maps we still need to 97 // add converts for sequence types this has a side effect of losing 98 // some information about arrays with known dimensions by creating: 99 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> 100 // !fir.ref<!fir.array<?xi32>> 101 if (auto refTy = op.getValue().getType().dyn_cast<fir::ReferenceType>()) 102 if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) { 103 fir::SequenceType::Shape flatShape = { 104 fir::SequenceType::getUnknownExtent()}; 105 auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy()); 106 auto flatTy = fir::ReferenceType::get(flatArrTy); 107 rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, 108 op.getValue()); 109 return success(); 110 } 111 rewriter.startRootUpdate(op->getParentOp()); 112 op.getResult().replaceAllUsesWith(op.getValue()); 113 rewriter.finalizeRootUpdate(op->getParentOp()); 114 rewriter.eraseOp(op); 115 } 116 return success(); 117 } 118 }; 119 120 mlir::Type convertMemRef(mlir::MemRefType type) { 121 return fir::SequenceType::get( 122 SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()), 123 type.getElementType()); 124 } 125 126 class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> { 127 public: 128 using OpRewritePattern::OpRewritePattern; 129 mlir::LogicalResult 130 matchAndRewrite(memref::AllocOp op, 131 mlir::PatternRewriter &rewriter) const override { 132 rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()), 133 op.memref()); 134 return success(); 135 } 136 }; 137 138 class AffineDialectDemotion 139 : public AffineDialectDemotionBase<AffineDialectDemotion> { 140 public: 141 void runOnOperation() override { 142 auto *context = &getContext(); 143 auto function = getOperation(); 144 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n"; 145 function.print(llvm::dbgs());); 146 147 mlir::RewritePatternSet patterns(context); 148 patterns.insert<ConvertConversion>(context); 149 patterns.insert<AffineLoadConversion>(context); 150 patterns.insert<AffineStoreConversion>(context); 151 patterns.insert<StdAllocConversion>(context); 152 mlir::ConversionTarget target(*context); 153 target.addIllegalOp<memref::AllocOp>(); 154 target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) { 155 if (op.getRes().getType().isa<mlir::MemRefType>()) 156 return false; 157 return true; 158 }); 159 target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, 160 mlir::arith::ArithmeticDialect, 161 mlir::StandardOpsDialect>(); 162 163 if (mlir::failed(mlir::applyPartialConversion(function, target, 164 std::move(patterns)))) { 165 mlir::emitError(mlir::UnknownLoc::get(context), 166 "error in converting affine dialect\n"); 167 signalPassFailure(); 168 } 169 } 170 }; 171 172 } // namespace 173 174 std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() { 175 return std::make_unique<AffineDialectDemotion>(); 176 } 177