1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Transforms/Passes.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "mlir/Transforms/Passes.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 #define DEBUG_TYPE "flang-abstract-result-opt"
23 
24 namespace fir {
25 namespace {
26 
getResultArgumentType(mlir::Type resultType,bool shouldBoxResult)27 static mlir::Type getResultArgumentType(mlir::Type resultType,
28                                         bool shouldBoxResult) {
29   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
30       .Case<fir::SequenceType, fir::RecordType>(
31           [&](mlir::Type type) -> mlir::Type {
32             if (shouldBoxResult)
33               return fir::BoxType::get(type);
34             return fir::ReferenceType::get(type);
35           })
36       .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
37         return fir::ReferenceType::get(type);
38       })
39       .Default([](mlir::Type) -> mlir::Type {
40         llvm_unreachable("bad abstract result type");
41       });
42 }
43 
getNewFunctionType(mlir::FunctionType funcTy,bool shouldBoxResult)44 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
45                                              bool shouldBoxResult) {
46   auto resultType = funcTy.getResult(0);
47   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
48   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
49   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
50   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
51                                  /*resultTypes=*/{});
52 }
53 
mustEmboxResult(mlir::Type resultType,bool shouldBoxResult)54 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
55   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
56          shouldBoxResult;
57 }
58 
59 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
60 public:
61   using OpRewritePattern::OpRewritePattern;
CallOpConversion(mlir::MLIRContext * context,bool shouldBoxResult)62   CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
63       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
64   mlir::LogicalResult
matchAndRewrite(fir::CallOp callOp,mlir::PatternRewriter & rewriter) const65   matchAndRewrite(fir::CallOp callOp,
66                   mlir::PatternRewriter &rewriter) const override {
67     auto loc = callOp.getLoc();
68     auto result = callOp->getResult(0);
69     if (!result.hasOneUse()) {
70       mlir::emitError(loc,
71                       "calls with abstract result must have exactly one user");
72       return mlir::failure();
73     }
74     auto saveResult =
75         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
76     if (!saveResult) {
77       mlir::emitError(
78           loc, "calls with abstract result must be used in fir.save_result");
79       return mlir::failure();
80     }
81     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
82     auto buffer = saveResult.getMemref();
83     mlir::Value arg = buffer;
84     if (mustEmboxResult(result.getType(), shouldBoxResult))
85       arg = rewriter.create<fir::EmboxOp>(
86           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
87           saveResult.getTypeparams());
88 
89     llvm::SmallVector<mlir::Type> newResultTypes;
90     if (callOp.getCallee()) {
91       llvm::SmallVector<mlir::Value> newOperands = {arg};
92       newOperands.append(callOp.getOperands().begin(),
93                          callOp.getOperands().end());
94       rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes,
95                                    newOperands);
96     } else {
97       // Indirect calls.
98       llvm::SmallVector<mlir::Type> newInputTypes = {argType};
99       for (auto operand : callOp.getOperands().drop_front())
100         newInputTypes.push_back(operand.getType());
101       auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
102                                            newResultTypes);
103 
104       llvm::SmallVector<mlir::Value> newOperands;
105       newOperands.push_back(
106           rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
107       newOperands.push_back(arg);
108       newOperands.append(callOp.getOperands().begin() + 1,
109                          callOp.getOperands().end());
110       rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
111                                    newOperands);
112     }
113     callOp->dropAllReferences();
114     rewriter.eraseOp(callOp);
115     return mlir::success();
116   }
117 
118 private:
119   bool shouldBoxResult;
120 };
121 
122 class SaveResultOpConversion
123     : public mlir::OpRewritePattern<fir::SaveResultOp> {
124 public:
125   using OpRewritePattern::OpRewritePattern;
SaveResultOpConversion(mlir::MLIRContext * context)126   SaveResultOpConversion(mlir::MLIRContext *context)
127       : OpRewritePattern(context) {}
128   mlir::LogicalResult
matchAndRewrite(fir::SaveResultOp op,mlir::PatternRewriter & rewriter) const129   matchAndRewrite(fir::SaveResultOp op,
130                   mlir::PatternRewriter &rewriter) const override {
131     rewriter.eraseOp(op);
132     return mlir::success();
133   }
134 };
135 
136 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
137 public:
138   using OpRewritePattern::OpRewritePattern;
ReturnOpConversion(mlir::MLIRContext * context,mlir::Value newArg)139   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
140       : OpRewritePattern(context), newArg{newArg} {}
141   mlir::LogicalResult
matchAndRewrite(mlir::func::ReturnOp ret,mlir::PatternRewriter & rewriter) const142   matchAndRewrite(mlir::func::ReturnOp ret,
143                   mlir::PatternRewriter &rewriter) const override {
144     rewriter.setInsertionPoint(ret);
145     auto returnedValue = ret.getOperand(0);
146     bool replacedStorage = false;
147     if (auto *op = returnedValue.getDefiningOp())
148       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
149         auto resultStorage = load.getMemref();
150         load.getMemref().replaceAllUsesWith(newArg);
151         replacedStorage = true;
152         if (auto *alloc = resultStorage.getDefiningOp())
153           if (alloc->use_empty())
154             rewriter.eraseOp(alloc);
155       }
156     // The result storage may have been optimized out by a memory to
157     // register pass, this is possible for fir.box results, or fir.record
158     // with no length parameters. Simply store the result in the result storage.
159     // at the return point.
160     if (!replacedStorage)
161       rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg);
162     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
163     return mlir::success();
164   }
165 
166 private:
167   mlir::Value newArg;
168 };
169 
170 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
171 public:
172   using OpRewritePattern::OpRewritePattern;
AddrOfOpConversion(mlir::MLIRContext * context,bool shouldBoxResult)173   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
174       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
175   mlir::LogicalResult
matchAndRewrite(fir::AddrOfOp addrOf,mlir::PatternRewriter & rewriter) const176   matchAndRewrite(fir::AddrOfOp addrOf,
177                   mlir::PatternRewriter &rewriter) const override {
178     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
179     auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
180     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
181                                                     addrOf.getSymbol());
182     // Rather than converting all op a function pointer might transit through
183     // (e.g calls, stores, loads, converts...), cast new type to the abstract
184     // type. A conversion will be added when calling indirect calls of abstract
185     // types.
186     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
187     return mlir::success();
188   }
189 
190 private:
191   bool shouldBoxResult;
192 };
193 
194 class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
195 public:
runOnOperation()196   void runOnOperation() override {
197     auto *context = &getContext();
198     auto func = getOperation();
199     auto loc = func.getLoc();
200     mlir::RewritePatternSet patterns(context);
201     mlir::ConversionTarget target = *context;
202     const bool shouldBoxResult = passResultAsBox.getValue();
203 
204     // Convert function type itself if it has an abstract result
205     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
206     if (hasAbstractResult(funcTy)) {
207       func.setType(getNewFunctionType(funcTy, shouldBoxResult));
208       unsigned zero = 0;
209       if (!func.empty()) {
210         // Insert new argument
211         mlir::OpBuilder rewriter(context);
212         auto resultType = funcTy.getResult(0);
213         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
214         mlir::Value newArg = func.front().insertArgument(zero, argTy, loc);
215         if (mustEmboxResult(resultType, shouldBoxResult)) {
216           auto bufferType = fir::ReferenceType::get(resultType);
217           rewriter.setInsertionPointToStart(&func.front());
218           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
219         }
220         patterns.insert<ReturnOpConversion>(context, newArg);
221         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
222             [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
223       }
224     }
225 
226     if (func.empty())
227       return;
228 
229     // Convert the calls and, if needed,  the ReturnOp in the function body.
230     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
231                            mlir::func::FuncDialect>();
232     target.addIllegalOp<fir::SaveResultOp>();
233     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
234       return !hasAbstractResult(call.getFunctionType());
235     });
236     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
237       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
238         return !hasAbstractResult(funTy);
239       return true;
240     });
241     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
242       if (dispatch->getNumResults() != 1)
243         return true;
244       auto resultType = dispatch->getResult(0).getType();
245       if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
246         TODO(dispatch.getLoc(), "dispatchOp with abstract results");
247         return false;
248       }
249       return true;
250     });
251 
252     patterns.insert<CallOpConversion>(context, shouldBoxResult);
253     patterns.insert<SaveResultOpConversion>(context);
254     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
255     if (mlir::failed(
256             mlir::applyPartialConversion(func, target, std::move(patterns)))) {
257       mlir::emitError(func.getLoc(), "error in converting abstract results\n");
258       signalPassFailure();
259     }
260   }
261 };
262 } // end anonymous namespace
263 } // namespace fir
264 
createAbstractResultOptPass()265 std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
266   return std::make_unique<AbstractResultOpt>();
267 }
268