1 //===-- AffineDemotion.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation is a prototype that demote affine dialects operations 10 // after optimizations to FIR loops operations. 11 // It is used after the AffinePromotion pass. 12 // It is not part of the production pipeline and would need more work in order 13 // to be used in production. 14 // More information can be found in this presentation: 15 // https://slides.com/rajanwalia/deck 16 // 17 //===----------------------------------------------------------------------===// 18 19 #include "PassDetail.h" 20 #include "flang/Optimizer/Dialect/FIRDialect.h" 21 #include "flang/Optimizer/Dialect/FIROps.h" 22 #include "flang/Optimizer/Dialect/FIRType.h" 23 #include "flang/Optimizer/Transforms/Passes.h" 24 #include "mlir/Dialect/Affine/IR/AffineOps.h" 25 #include "mlir/Dialect/Affine/Utils.h" 26 #include "mlir/Dialect/Func/IR/FuncOps.h" 27 #include "mlir/Dialect/MemRef/IR/MemRef.h" 28 #include "mlir/Dialect/SCF/SCF.h" 29 #include "mlir/IR/BuiltinAttributes.h" 30 #include "mlir/IR/IntegerSet.h" 31 #include "mlir/IR/Visitors.h" 32 #include "mlir/Pass/Pass.h" 33 #include "mlir/Transforms/DialectConversion.h" 34 #include "llvm/ADT/DenseMap.h" 35 #include "llvm/ADT/Optional.h" 36 #include "llvm/Support/CommandLine.h" 37 #include "llvm/Support/Debug.h" 38 39 #define DEBUG_TYPE "flang-affine-demotion" 40 41 using namespace fir; 42 using namespace mlir; 43 44 namespace { 45 46 class AffineLoadConversion : public OpConversionPattern<mlir::AffineLoadOp> { 47 public: 48 using OpConversionPattern<mlir::AffineLoadOp>::OpConversionPattern; 49 50 LogicalResult 51 matchAndRewrite(mlir::AffineLoadOp op, OpAdaptor adaptor, 52 ConversionPatternRewriter &rewriter) const override { 53 SmallVector<Value> indices(adaptor.indices()); 54 auto maybeExpandedMap = 55 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 56 if (!maybeExpandedMap) 57 return failure(); 58 59 auto coorOp = rewriter.create<fir::CoordinateOp>( 60 op.getLoc(), fir::ReferenceType::get(op.getResult().getType()), 61 adaptor.memref(), *maybeExpandedMap); 62 63 rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult()); 64 return success(); 65 } 66 }; 67 68 class AffineStoreConversion : public OpConversionPattern<mlir::AffineStoreOp> { 69 public: 70 using OpConversionPattern<mlir::AffineStoreOp>::OpConversionPattern; 71 72 LogicalResult 73 matchAndRewrite(mlir::AffineStoreOp op, OpAdaptor adaptor, 74 ConversionPatternRewriter &rewriter) const override { 75 SmallVector<Value> indices(op.indices()); 76 auto maybeExpandedMap = 77 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); 78 if (!maybeExpandedMap) 79 return failure(); 80 81 auto coorOp = rewriter.create<fir::CoordinateOp>( 82 op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()), 83 adaptor.memref(), *maybeExpandedMap); 84 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.value(), 85 coorOp.getResult()); 86 return success(); 87 } 88 }; 89 90 class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> { 91 public: 92 using OpRewritePattern::OpRewritePattern; 93 mlir::LogicalResult 94 matchAndRewrite(fir::ConvertOp op, 95 mlir::PatternRewriter &rewriter) const override { 96 if (op.getRes().getType().isa<mlir::MemRefType>()) { 97 // due to index calculation moving to affine maps we still need to 98 // add converts for sequence types this has a side effect of losing 99 // some information about arrays with known dimensions by creating: 100 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) -> 101 // !fir.ref<!fir.array<?xi32>> 102 if (auto refTy = op.getValue().getType().dyn_cast<fir::ReferenceType>()) 103 if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) { 104 fir::SequenceType::Shape flatShape = { 105 fir::SequenceType::getUnknownExtent()}; 106 auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy()); 107 auto flatTy = fir::ReferenceType::get(flatArrTy); 108 rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, 109 op.getValue()); 110 return success(); 111 } 112 rewriter.startRootUpdate(op->getParentOp()); 113 op.getResult().replaceAllUsesWith(op.getValue()); 114 rewriter.finalizeRootUpdate(op->getParentOp()); 115 rewriter.eraseOp(op); 116 } 117 return success(); 118 } 119 }; 120 121 mlir::Type convertMemRef(mlir::MemRefType type) { 122 return fir::SequenceType::get( 123 SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()), 124 type.getElementType()); 125 } 126 127 class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> { 128 public: 129 using OpRewritePattern::OpRewritePattern; 130 mlir::LogicalResult 131 matchAndRewrite(memref::AllocOp op, 132 mlir::PatternRewriter &rewriter) const override { 133 rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()), 134 op.memref()); 135 return success(); 136 } 137 }; 138 139 class AffineDialectDemotion 140 : public AffineDialectDemotionBase<AffineDialectDemotion> { 141 public: 142 void runOnOperation() override { 143 auto *context = &getContext(); 144 auto function = getOperation(); 145 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n"; 146 function.print(llvm::dbgs());); 147 148 mlir::RewritePatternSet patterns(context); 149 patterns.insert<ConvertConversion>(context); 150 patterns.insert<AffineLoadConversion>(context); 151 patterns.insert<AffineStoreConversion>(context); 152 patterns.insert<StdAllocConversion>(context); 153 mlir::ConversionTarget target(*context); 154 target.addIllegalOp<memref::AllocOp>(); 155 target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) { 156 if (op.getRes().getType().isa<mlir::MemRefType>()) 157 return false; 158 return true; 159 }); 160 target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect, 161 mlir::arith::ArithmeticDialect, 162 mlir::func::FuncDialect>(); 163 164 if (mlir::failed(mlir::applyPartialConversion(function, target, 165 std::move(patterns)))) { 166 mlir::emitError(mlir::UnknownLoc::get(context), 167 "error in converting affine dialect\n"); 168 signalPassFailure(); 169 } 170 } 171 }; 172 173 } // namespace 174 175 std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() { 176 return std::make_unique<AffineDialectDemotion>(); 177 } 178