1 //===-- AffineDemotion.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This transformation is a prototype that demotes affine dialect operations
// to FIR loop operations after optimizations.
11 // It is used after the AffinePromotion pass.
12 // It is not part of the production pipeline and would need more work in order
13 // to be used in production.
14 // More information can be found in this presentation:
15 // https://slides.com/rajanwalia/deck
16 //
17 //===----------------------------------------------------------------------===//
18 
19 #include "PassDetail.h"
20 #include "flang/Optimizer/Dialect/FIRDialect.h"
21 #include "flang/Optimizer/Dialect/FIROps.h"
22 #include "flang/Optimizer/Dialect/FIRType.h"
23 #include "flang/Optimizer/Transforms/Passes.h"
24 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
25 #include "mlir/Dialect/Affine/IR/AffineOps.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/SCF/SCF.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"
29 #include "mlir/IR/BuiltinAttributes.h"
30 #include "mlir/IR/IntegerSet.h"
31 #include "mlir/IR/Visitors.h"
32 #include "mlir/Pass/Pass.h"
33 #include "mlir/Transforms/DialectConversion.h"
34 #include "llvm/ADT/DenseMap.h"
35 #include "llvm/ADT/Optional.h"
36 #include "llvm/Support/CommandLine.h"
37 #include "llvm/Support/Debug.h"
38 
39 #define DEBUG_TYPE "flang-affine-demotion"
40 
41 using namespace fir;
42 
43 namespace {
44 
45 class AffineLoadConversion : public OpRewritePattern<mlir::AffineLoadOp> {
46 public:
47   using OpRewritePattern<mlir::AffineLoadOp>::OpRewritePattern;
48 
49   LogicalResult matchAndRewrite(mlir::AffineLoadOp op,
50                                 PatternRewriter &rewriter) const override {
51     SmallVector<Value> indices(op.getMapOperands());
52     auto maybeExpandedMap =
53         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
54     if (!maybeExpandedMap)
55       return failure();
56 
57     auto coorOp = rewriter.create<fir::CoordinateOp>(
58         op.getLoc(), fir::ReferenceType::get(op.getResult().getType()),
59         op.getMemRef(), *maybeExpandedMap);
60 
61     rewriter.replaceOpWithNewOp<fir::LoadOp>(op, coorOp.getResult());
62     return success();
63   }
64 };
65 
66 class AffineStoreConversion : public OpRewritePattern<mlir::AffineStoreOp> {
67 public:
68   using OpRewritePattern<mlir::AffineStoreOp>::OpRewritePattern;
69 
70   LogicalResult matchAndRewrite(mlir::AffineStoreOp op,
71                                 PatternRewriter &rewriter) const override {
72     SmallVector<Value> indices(op.getMapOperands());
73     auto maybeExpandedMap =
74         expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices);
75     if (!maybeExpandedMap)
76       return failure();
77 
78     auto coorOp = rewriter.create<fir::CoordinateOp>(
79         op.getLoc(), fir::ReferenceType::get(op.getValueToStore().getType()),
80         op.getMemRef(), *maybeExpandedMap);
81     rewriter.replaceOpWithNewOp<fir::StoreOp>(op, op.getValueToStore(),
82                                               coorOp.getResult());
83     return success();
84   }
85 };
86 
87 class ConvertConversion : public mlir::OpRewritePattern<fir::ConvertOp> {
88 public:
89   using OpRewritePattern::OpRewritePattern;
90   mlir::LogicalResult
91   matchAndRewrite(fir::ConvertOp op,
92                   mlir::PatternRewriter &rewriter) const override {
93     if (op.res().getType().isa<mlir::MemRefType>()) {
94       // due to index calculation moving to affine maps we still need to
95       // add converts for sequence types this has a side effect of losing
96       // some information about arrays with known dimensions by creating:
97       // fir.convert %arg0 : (!fir.ref<!fir.array<5xi32>>) ->
98       // !fir.ref<!fir.array<?xi32>>
99       if (auto refTy = op.value().getType().dyn_cast<fir::ReferenceType>())
100         if (auto arrTy = refTy.getEleTy().dyn_cast<fir::SequenceType>()) {
101           fir::SequenceType::Shape flatShape = {
102               fir::SequenceType::getUnknownExtent()};
103           auto flatArrTy = fir::SequenceType::get(flatShape, arrTy.getEleTy());
104           auto flatTy = fir::ReferenceType::get(flatArrTy);
105           rewriter.replaceOpWithNewOp<fir::ConvertOp>(op, flatTy, op.value());
106           return success();
107         }
108       rewriter.startRootUpdate(op->getParentOp());
109       op.getResult().replaceAllUsesWith(op.value());
110       rewriter.finalizeRootUpdate(op->getParentOp());
111       rewriter.eraseOp(op);
112     }
113     return success();
114   }
115 };
116 
117 mlir::Type convertMemRef(mlir::MemRefType type) {
118   return fir::SequenceType::get(
119       SmallVector<int64_t>(type.getShape().begin(), type.getShape().end()),
120       type.getElementType());
121 }
122 
123 class StdAllocConversion : public mlir::OpRewritePattern<memref::AllocOp> {
124 public:
125   using OpRewritePattern::OpRewritePattern;
126   mlir::LogicalResult
127   matchAndRewrite(memref::AllocOp op,
128                   mlir::PatternRewriter &rewriter) const override {
129     rewriter.replaceOpWithNewOp<fir::AllocaOp>(op, convertMemRef(op.getType()),
130                                                op.memref());
131     return success();
132   }
133 };
134 
135 class AffineDialectDemotion
136     : public AffineDialectDemotionBase<AffineDialectDemotion> {
137 public:
138   void runOnFunction() override {
139     auto *context = &getContext();
140     auto function = getFunction();
141     LLVM_DEBUG(llvm::dbgs() << "AffineDemotion: running on function:\n";
142                function.print(llvm::dbgs()););
143 
144     mlir::OwningRewritePatternList patterns(context);
145     patterns.insert<ConvertConversion>(context);
146     patterns.insert<AffineLoadConversion>(context);
147     patterns.insert<AffineStoreConversion>(context);
148     patterns.insert<StdAllocConversion>(context);
149     mlir::ConversionTarget target(*context);
150     target.addIllegalOp<memref::AllocOp>();
151     target.addDynamicallyLegalOp<fir::ConvertOp>([](fir::ConvertOp op) {
152       if (op.res().getType().isa<mlir::MemRefType>())
153         return false;
154       return true;
155     });
156     target.addLegalDialect<FIROpsDialect, mlir::scf::SCFDialect,
157                            mlir::arith::ArithmeticDialect,
158                            mlir::StandardOpsDialect>();
159 
160     if (mlir::failed(mlir::applyPartialConversion(function, target,
161                                                   std::move(patterns)))) {
162       mlir::emitError(mlir::UnknownLoc::get(context),
163                       "error in converting affine dialect\n");
164       signalPassFailure();
165     }
166   }
167 };
168 
169 } // namespace
170 
171 std::unique_ptr<mlir::Pass> fir::createAffineDemotionPass() {
172   return std::make_unique<AffineDialectDemotion>();
173 }
174