1 //===-- AffineDemotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This transformation is a prototype that demotes affine dialect operations,
// after optimization, back to FIR loop operations.
11 // It is used after the AffinePromotion pass.
12 // It is not part of the production pipeline and would need more work in order
13 // to be used in production.
14 // More information can be found in this presentation:
15 // https://slides.com/rajanwalia/deck
16 //
17 //===----------------------------------------------------------------------===//
18
19 #include "PassDetail.h"
20 #include "flang/Optimizer/Dialect/FIRDialect.h"
21 #include "flang/Optimizer/Dialect/FIROps.h"
22 #include "flang/Optimizer/Dialect/FIRType.h"
23 #include "flang/Optimizer/Transforms/Passes.h"
24 #include "mlir/Dialect/Affine/IR/AffineOps.h"
25 #include "mlir/Dialect/Affine/Utils.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"
27 #include "mlir/Dialect/MemRef/IR/MemRef.h"
28 #include "mlir/Dialect/SCF/IR/SCF.h"
29 #include "mlir/IR/BuiltinAttributes.h"
30 #include "mlir/IR/IntegerSet.h"
31 #include "mlir/IR/Visitors.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Transforms/DialectConversion.h"
34 #include "llvm/ADT/DenseMap.h"
35 #include "llvm/ADT/Optional.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/Debug.h"
38
39 #define DEBUG_TYPE "flang-affine-demotion"
40
41 using namespace fir;
42 using namespace mlir;
43
44 namespace {
45
46 class AffineLoadConversion : public OpConversionPattern<mlir::AffineLoadOp> {
47 public:
48 using OpConversionPattern<mlir::AffineLoadOp>::OpConversionPattern;
49
50 LogicalResult
matchAndRewrite(mlir::AffineLoadOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const51 matchAndRewrite(mlir::AffineLoadOp op, OpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter) const override {
53 SmallVector<Value> indices(adaptor.getIndices());
54 auto maybeExpandedMap =
55 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
56 if (!maybeExpandedMap)
57 return failure();
58
59 auto coorOp = rewriter.create<fir::CoordinateOp>(
60 op.getLoc(), fir::ReferenceType::get(op.getResult().getType()),
61 adaptor.getMemref(), *maybeExpandedMap);
62
63 rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult());
64 return success();
65 }
66 };
67
68 class AffineStoreConversion : public OpConversionPattern<mlir::AffineStoreOp> {
69 public:
70 using OpConversionPattern<mlir::AffineStoreOp>::OpConversionPattern;
71
72 LogicalResult
matchAndRewrite(mlir::AffineStoreOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const73 matchAndRewrite(mlir::AffineStoreOp op, OpAdaptor adaptor,
74 ConversionPatternRewriter &rewriter) const override {
75 SmallVector<Value> indices(op.getIndices());
76 auto maybeExpandedMap =
77 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
78 if (!maybeExpandedMap)
79 return failure();
80
81 auto coorOp = rewriter.create<fir::CoordinateOp>(
82 op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()),
83 adaptor.getMemref(), *maybeExpandedMap);
84 rewriter.replaceOpWithNewOp<fir::StoreOp>(op, adaptor.getValue(),
85 coorOp.getResult());
86 return success();
87 }
88 };
89
90 class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
91 public:
92 using OpRewritePattern::OpRewritePattern;
93 mlir::LogicalResult
matchAndRewrite(fir::ConvertOp op,mlir::PatternRewriter & rewriter) const94 matchAndRewrite(fir::ConvertOp op,
95 mlir::PatternRewriter &rewriter) const override {
96 if (op.getRes().getType().isa<mlir::MemRefType>()) {
97 // due to index calculation moving to affine maps we still need to
98 // add converts for sequence types this has a side effect of losing
99 // some information about arrays with known dimensions by creating:
100 // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
101 // !fir.ref<!fir.array<?xi32>>
102 if (auto refTy = op.getValue().getType().dyn_cast<fir::ReferenceType>())
103 if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) {
104 fir::SequenceType::Shape flatShape = {
105 fir::SequenceType::getUnknownExtent()};
106 auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy());
107 auto flatTy = fir::ReferenceType::get(flatArrTy);
108 rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy,
109 op.getValue());
110 return success();
111 }
112 rewriter.startRootUpdate(op->getParentOp());
113 op.getResult().replaceAllUsesWith(op.getValue());
114 rewriter.finalizeRootUpdate(op->getParentOp());
115 rewriter.eraseOp(op);
116 }
117 return success();
118 }
119 };
120
convertMemRef(mlir::MemRefType type)121 mlir::Type convertMemRef(mlir::MemRefType type) {
122 return fir::SequenceType::get(
123 SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()),
124 type.getElementType());
125 }
126
127 class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> {
128 public:
129 using OpRewritePattern::OpRewritePattern;
130 mlir::LogicalResult
matchAndRewrite(memref::AllocOp op,mlir::PatternRewriter & rewriter) const131 matchAndRewrite(memref::AllocOp op,
132 mlir::PatternRewriter &rewriter) const override {
133 rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()),
134 op.memref());
135 return success();
136 }
137 };
138
139 class AffineDialectDemotion
140 : public AffineDialectDemotionBase<AffineDialectDemotion> {
141 public:
runOnOperation()142 void runOnOperation() override {
143 auto *context = &getContext();
144 auto function = getOperation();
145 LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
146 function.print(llvm::dbgs()););
147
148 mlir::RewritePatternSet patterns(context);
149 patterns.insert<ConvertConversion>(context);
150 patterns.insert<AffineLoadConversion>(context);
151 patterns.insert<AffineStoreConversion>(context);
152 patterns.insert<StdAllocConversion>(context);
153 mlir::ConversionTarget target(*context);
154 target.addIllegalOp<memref::AllocOp>();
155 target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) {
156 if (op.getRes().getType().isa<mlir::MemRefType>())
157 return false;
158 return true;
159 });
160 target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
161 mlir::arith::ArithmeticDialect,
162 mlir::func::FuncDialect>();
163
164 if (mlir::failed(mlir::applyPartialConversion(function, target,
165 std::move(patterns)))) {
166 mlir::emitError(mlir::UnknownLoc::get(context),
167 "error in converting affine dialect\n");
168 signalPassFailure();
169 }
170 }
171 };
172
173 } // namespace
174
createAffineDemotionPass()175 std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() {
176 return std::make_unique<AffineDialectDemotion>();
177 }
178