1b0eef1eeSValentin Clement //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2b0eef1eeSValentin Clement //
3b0eef1eeSValentin Clement // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4b0eef1eeSValentin Clement // See https://llvm.org/LICENSE.txt for license information.
5b0eef1eeSValentin Clement // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6b0eef1eeSValentin Clement //
7b0eef1eeSValentin Clement //===----------------------------------------------------------------------===//
8b0eef1eeSValentin Clement 
9b0eef1eeSValentin Clement #include "PassDetail.h"
109de831aaSJean Perier #include "flang/Optimizer/Builder/Todo.h"
11b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIRDialect.h"
12b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIROps.h"
13b0eef1eeSValentin Clement #include "flang/Optimizer/Dialect/FIRType.h"
14b0eef1eeSValentin Clement #include "flang/Optimizer/Transforms/Passes.h"
1523aa5a74SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
16b0eef1eeSValentin Clement #include "mlir/IR/Diagnostics.h"
17b0eef1eeSValentin Clement #include "mlir/Pass/Pass.h"
18b0eef1eeSValentin Clement #include "mlir/Transforms/DialectConversion.h"
19b0eef1eeSValentin Clement #include "mlir/Transforms/Passes.h"
20b0eef1eeSValentin Clement #include "llvm/ADT/TypeSwitch.h"
21b0eef1eeSValentin Clement 
22b0eef1eeSValentin Clement #define DEBUG_TYPE "flang-abstract-result-opt"
23b0eef1eeSValentin Clement 
24b0eef1eeSValentin Clement namespace fir {
25b0eef1eeSValentin Clement namespace {
26b0eef1eeSValentin Clement 
getResultArgumentType(mlir::Type resultType,bool shouldBoxResult)27b0eef1eeSValentin Clement static mlir::Type getResultArgumentType(mlir::Type resultType,
28*ea1cdb58SDaniil Dudkin                                         bool shouldBoxResult) {
29b0eef1eeSValentin Clement   return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
30b0eef1eeSValentin Clement       .Case<fir::SequenceType, fir::RecordType>(
31b0eef1eeSValentin Clement           [&](mlir::Type type) -> mlir::Type {
32*ea1cdb58SDaniil Dudkin             if (shouldBoxResult)
33b0eef1eeSValentin Clement               return fir::BoxType::get(type);
34b0eef1eeSValentin Clement             return fir::ReferenceType::get(type);
35b0eef1eeSValentin Clement           })
36b0eef1eeSValentin Clement       .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
37b0eef1eeSValentin Clement         return fir::ReferenceType::get(type);
38b0eef1eeSValentin Clement       })
39b0eef1eeSValentin Clement       .Default([](mlir::Type) -> mlir::Type {
40b0eef1eeSValentin Clement         llvm_unreachable("bad abstract result type");
41b0eef1eeSValentin Clement       });
42b0eef1eeSValentin Clement }
43b0eef1eeSValentin Clement 
getNewFunctionType(mlir::FunctionType funcTy,bool shouldBoxResult)44*ea1cdb58SDaniil Dudkin static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
45*ea1cdb58SDaniil Dudkin                                              bool shouldBoxResult) {
46b0eef1eeSValentin Clement   auto resultType = funcTy.getResult(0);
47*ea1cdb58SDaniil Dudkin   auto argTy = getResultArgumentType(resultType, shouldBoxResult);
48b0eef1eeSValentin Clement   llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
49b0eef1eeSValentin Clement   newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
50b0eef1eeSValentin Clement   return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
51b0eef1eeSValentin Clement                                  /*resultTypes=*/{});
52b0eef1eeSValentin Clement }
53b0eef1eeSValentin Clement 
mustEmboxResult(mlir::Type resultType,bool shouldBoxResult)54*ea1cdb58SDaniil Dudkin static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
55b0eef1eeSValentin Clement   return resultType.isa<fir::SequenceType, fir::RecordType>() &&
56*ea1cdb58SDaniil Dudkin          shouldBoxResult;
57b0eef1eeSValentin Clement }
58b0eef1eeSValentin Clement 
59b0eef1eeSValentin Clement class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
60b0eef1eeSValentin Clement public:
61b0eef1eeSValentin Clement   using OpRewritePattern::OpRewritePattern;
CallOpConversion(mlir::MLIRContext * context,bool shouldBoxResult)62*ea1cdb58SDaniil Dudkin   CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
63*ea1cdb58SDaniil Dudkin       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
64b0eef1eeSValentin Clement   mlir::LogicalResult
matchAndRewrite(fir::CallOp callOp,mlir::PatternRewriter & rewriter) const65b0eef1eeSValentin Clement   matchAndRewrite(fir::CallOp callOp,
66b0eef1eeSValentin Clement                   mlir::PatternRewriter &rewriter) const override {
67b0eef1eeSValentin Clement     auto loc = callOp.getLoc();
68b0eef1eeSValentin Clement     auto result = callOp->getResult(0);
69b0eef1eeSValentin Clement     if (!result.hasOneUse()) {
70b0eef1eeSValentin Clement       mlir::emitError(loc,
71b0eef1eeSValentin Clement                       "calls with abstract result must have exactly one user");
72b0eef1eeSValentin Clement       return mlir::failure();
73b0eef1eeSValentin Clement     }
74b0eef1eeSValentin Clement     auto saveResult =
75b0eef1eeSValentin Clement         mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
76b0eef1eeSValentin Clement     if (!saveResult) {
77b0eef1eeSValentin Clement       mlir::emitError(
78b0eef1eeSValentin Clement           loc, "calls with abstract result must be used in fir.save_result");
79b0eef1eeSValentin Clement       return mlir::failure();
80b0eef1eeSValentin Clement     }
81*ea1cdb58SDaniil Dudkin     auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
82149ad3d5SShraiysh Vaishay     auto buffer = saveResult.getMemref();
83b0eef1eeSValentin Clement     mlir::Value arg = buffer;
84*ea1cdb58SDaniil Dudkin     if (mustEmboxResult(result.getType(), shouldBoxResult))
85b0eef1eeSValentin Clement       arg = rewriter.create<fir::EmboxOp>(
86149ad3d5SShraiysh Vaishay           loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
87149ad3d5SShraiysh Vaishay           saveResult.getTypeparams());
88b0eef1eeSValentin Clement 
89b0eef1eeSValentin Clement     llvm::SmallVector<mlir::Type> newResultTypes;
90149ad3d5SShraiysh Vaishay     if (callOp.getCallee()) {
91b0eef1eeSValentin Clement       llvm::SmallVector<mlir::Value> newOperands = {arg};
92b0eef1eeSValentin Clement       newOperands.append(callOp.getOperands().begin(),
93b0eef1eeSValentin Clement                          callOp.getOperands().end());
94009ab172SKazu Hirata       rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes,
95009ab172SKazu Hirata                                    newOperands);
96b0eef1eeSValentin Clement     } else {
97b0eef1eeSValentin Clement       // Indirect calls.
98b0eef1eeSValentin Clement       llvm::SmallVector<mlir::Type> newInputTypes = {argType};
99b0eef1eeSValentin Clement       for (auto operand : callOp.getOperands().drop_front())
100b0eef1eeSValentin Clement         newInputTypes.push_back(operand.getType());
101b0eef1eeSValentin Clement       auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
102b0eef1eeSValentin Clement                                            newResultTypes);
103b0eef1eeSValentin Clement 
104b0eef1eeSValentin Clement       llvm::SmallVector<mlir::Value> newOperands;
105b0eef1eeSValentin Clement       newOperands.push_back(
106b0eef1eeSValentin Clement           rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
107b0eef1eeSValentin Clement       newOperands.push_back(arg);
108b0eef1eeSValentin Clement       newOperands.append(callOp.getOperands().begin() + 1,
109b0eef1eeSValentin Clement                          callOp.getOperands().end());
110b0eef1eeSValentin Clement       rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
111b0eef1eeSValentin Clement                                    newOperands);
112b0eef1eeSValentin Clement     }
113b0eef1eeSValentin Clement     callOp->dropAllReferences();
114b0eef1eeSValentin Clement     rewriter.eraseOp(callOp);
115b0eef1eeSValentin Clement     return mlir::success();
116b0eef1eeSValentin Clement   }
117b0eef1eeSValentin Clement 
118b0eef1eeSValentin Clement private:
119*ea1cdb58SDaniil Dudkin   bool shouldBoxResult;
120b0eef1eeSValentin Clement };
121b0eef1eeSValentin Clement 
122b0eef1eeSValentin Clement class SaveResultOpConversion
123b0eef1eeSValentin Clement     : public mlir::OpRewritePattern<fir::SaveResultOp> {
124b0eef1eeSValentin Clement public:
125b0eef1eeSValentin Clement   using OpRewritePattern::OpRewritePattern;
SaveResultOpConversion(mlir::MLIRContext * context)126b0eef1eeSValentin Clement   SaveResultOpConversion(mlir::MLIRContext *context)
127b0eef1eeSValentin Clement       : OpRewritePattern(context) {}
128b0eef1eeSValentin Clement   mlir::LogicalResult
matchAndRewrite(fir::SaveResultOp op,mlir::PatternRewriter & rewriter) const129b0eef1eeSValentin Clement   matchAndRewrite(fir::SaveResultOp op,
130b0eef1eeSValentin Clement                   mlir::PatternRewriter &rewriter) const override {
131b0eef1eeSValentin Clement     rewriter.eraseOp(op);
132b0eef1eeSValentin Clement     return mlir::success();
133b0eef1eeSValentin Clement   }
134b0eef1eeSValentin Clement };
135b0eef1eeSValentin Clement 
13623aa5a74SRiver Riddle class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
137b0eef1eeSValentin Clement public:
138b0eef1eeSValentin Clement   using OpRewritePattern::OpRewritePattern;
ReturnOpConversion(mlir::MLIRContext * context,mlir::Value newArg)139*ea1cdb58SDaniil Dudkin   ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
140*ea1cdb58SDaniil Dudkin       : OpRewritePattern(context), newArg{newArg} {}
141b0eef1eeSValentin Clement   mlir::LogicalResult
matchAndRewrite(mlir::func::ReturnOp ret,mlir::PatternRewriter & rewriter) const14223aa5a74SRiver Riddle   matchAndRewrite(mlir::func::ReturnOp ret,
143b0eef1eeSValentin Clement                   mlir::PatternRewriter &rewriter) const override {
144b0eef1eeSValentin Clement     rewriter.setInsertionPoint(ret);
145b0eef1eeSValentin Clement     auto returnedValue = ret.getOperand(0);
146b0eef1eeSValentin Clement     bool replacedStorage = false;
147b0eef1eeSValentin Clement     if (auto *op = returnedValue.getDefiningOp())
148b0eef1eeSValentin Clement       if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
149149ad3d5SShraiysh Vaishay         auto resultStorage = load.getMemref();
150*ea1cdb58SDaniil Dudkin         load.getMemref().replaceAllUsesWith(newArg);
151b0eef1eeSValentin Clement         replacedStorage = true;
152b0eef1eeSValentin Clement         if (auto *alloc = resultStorage.getDefiningOp())
153b0eef1eeSValentin Clement           if (alloc->use_empty())
154b0eef1eeSValentin Clement             rewriter.eraseOp(alloc);
155b0eef1eeSValentin Clement       }
156b0eef1eeSValentin Clement     // The result storage may have been optimized out by a memory to
157b0eef1eeSValentin Clement     // register pass, this is possible for fir.box results, or fir.record
158b0eef1eeSValentin Clement     // with no length parameters. Simply store the result in the result storage.
159b0eef1eeSValentin Clement     // at the return point.
160b0eef1eeSValentin Clement     if (!replacedStorage)
161*ea1cdb58SDaniil Dudkin       rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg);
16223aa5a74SRiver Riddle     rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
163b0eef1eeSValentin Clement     return mlir::success();
164b0eef1eeSValentin Clement   }
165b0eef1eeSValentin Clement 
166b0eef1eeSValentin Clement private:
167*ea1cdb58SDaniil Dudkin   mlir::Value newArg;
168b0eef1eeSValentin Clement };
169b0eef1eeSValentin Clement 
170b0eef1eeSValentin Clement class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
171b0eef1eeSValentin Clement public:
172b0eef1eeSValentin Clement   using OpRewritePattern::OpRewritePattern;
AddrOfOpConversion(mlir::MLIRContext * context,bool shouldBoxResult)173*ea1cdb58SDaniil Dudkin   AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
174*ea1cdb58SDaniil Dudkin       : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
175b0eef1eeSValentin Clement   mlir::LogicalResult
matchAndRewrite(fir::AddrOfOp addrOf,mlir::PatternRewriter & rewriter) const176b0eef1eeSValentin Clement   matchAndRewrite(fir::AddrOfOp addrOf,
177b0eef1eeSValentin Clement                   mlir::PatternRewriter &rewriter) const override {
178b0eef1eeSValentin Clement     auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
179*ea1cdb58SDaniil Dudkin     auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
180b0eef1eeSValentin Clement     auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
181149ad3d5SShraiysh Vaishay                                                     addrOf.getSymbol());
182b0eef1eeSValentin Clement     // Rather than converting all op a function pointer might transit through
183b0eef1eeSValentin Clement     // (e.g calls, stores, loads, converts...), cast new type to the abstract
184b0eef1eeSValentin Clement     // type. A conversion will be added when calling indirect calls of abstract
185b0eef1eeSValentin Clement     // types.
186b0eef1eeSValentin Clement     rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
187b0eef1eeSValentin Clement     return mlir::success();
188b0eef1eeSValentin Clement   }
189b0eef1eeSValentin Clement 
190b0eef1eeSValentin Clement private:
191*ea1cdb58SDaniil Dudkin   bool shouldBoxResult;
192b0eef1eeSValentin Clement };
193b0eef1eeSValentin Clement 
194b0eef1eeSValentin Clement class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
195b0eef1eeSValentin Clement public:
runOnOperation()196b0eef1eeSValentin Clement   void runOnOperation() override {
197b0eef1eeSValentin Clement     auto *context = &getContext();
198b0eef1eeSValentin Clement     auto func = getOperation();
199b0eef1eeSValentin Clement     auto loc = func.getLoc();
2009f85c198SRiver Riddle     mlir::RewritePatternSet patterns(context);
201b0eef1eeSValentin Clement     mlir::ConversionTarget target = *context;
202*ea1cdb58SDaniil Dudkin     const bool shouldBoxResult = passResultAsBox.getValue();
203b0eef1eeSValentin Clement 
204b0eef1eeSValentin Clement     // Convert function type itself if it has an abstract result
2054a3460a7SRiver Riddle     auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
2069de831aaSJean Perier     if (hasAbstractResult(funcTy)) {
207*ea1cdb58SDaniil Dudkin       func.setType(getNewFunctionType(funcTy, shouldBoxResult));
208b0eef1eeSValentin Clement       unsigned zero = 0;
209b0eef1eeSValentin Clement       if (!func.empty()) {
210b0eef1eeSValentin Clement         // Insert new argument
211b0eef1eeSValentin Clement         mlir::OpBuilder rewriter(context);
212b0eef1eeSValentin Clement         auto resultType = funcTy.getResult(0);
213*ea1cdb58SDaniil Dudkin         auto argTy = getResultArgumentType(resultType, shouldBoxResult);
214*ea1cdb58SDaniil Dudkin         mlir::Value newArg = func.front().insertArgument(zero, argTy, loc);
215*ea1cdb58SDaniil Dudkin         if (mustEmboxResult(resultType, shouldBoxResult)) {
216b0eef1eeSValentin Clement           auto bufferType = fir::ReferenceType::get(resultType);
217b0eef1eeSValentin Clement           rewriter.setInsertionPointToStart(&func.front());
218*ea1cdb58SDaniil Dudkin           newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
219b0eef1eeSValentin Clement         }
220*ea1cdb58SDaniil Dudkin         patterns.insert<ReturnOpConversion>(context, newArg);
22123aa5a74SRiver Riddle         target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
22223aa5a74SRiver Riddle             [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
223b0eef1eeSValentin Clement       }
224b0eef1eeSValentin Clement     }
225b0eef1eeSValentin Clement 
226b0eef1eeSValentin Clement     if (func.empty())
227b0eef1eeSValentin Clement       return;
228b0eef1eeSValentin Clement 
229b0eef1eeSValentin Clement     // Convert the calls and, if needed,  the ReturnOp in the function body.
230a54f4eaeSMogball     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
23123aa5a74SRiver Riddle                            mlir::func::FuncDialect>();
232b0eef1eeSValentin Clement     target.addIllegalOp<fir::SaveResultOp>();
233b0eef1eeSValentin Clement     target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
2349de831aaSJean Perier       return !hasAbstractResult(call.getFunctionType());
235b0eef1eeSValentin Clement     });
236b0eef1eeSValentin Clement     target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
237b0eef1eeSValentin Clement       if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
2389de831aaSJean Perier         return !hasAbstractResult(funTy);
239b0eef1eeSValentin Clement       return true;
240b0eef1eeSValentin Clement     });
241b0eef1eeSValentin Clement     target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
242b0eef1eeSValentin Clement       if (dispatch->getNumResults() != 1)
243b0eef1eeSValentin Clement         return true;
244b0eef1eeSValentin Clement       auto resultType = dispatch->getResult(0).getType();
245b0eef1eeSValentin Clement       if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
2469de831aaSJean Perier         TODO(dispatch.getLoc(), "dispatchOp with abstract results");
247b0eef1eeSValentin Clement         return false;
248b0eef1eeSValentin Clement       }
249b0eef1eeSValentin Clement       return true;
250b0eef1eeSValentin Clement     });
251b0eef1eeSValentin Clement 
252*ea1cdb58SDaniil Dudkin     patterns.insert<CallOpConversion>(context, shouldBoxResult);
253b0eef1eeSValentin Clement     patterns.insert<SaveResultOpConversion>(context);
254*ea1cdb58SDaniil Dudkin     patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
255b0eef1eeSValentin Clement     if (mlir::failed(
256b0eef1eeSValentin Clement             mlir::applyPartialConversion(func, target, std::move(patterns)))) {
257b0eef1eeSValentin Clement       mlir::emitError(func.getLoc(), "error in converting abstract results\n");
258b0eef1eeSValentin Clement       signalPassFailure();
259b0eef1eeSValentin Clement     }
260b0eef1eeSValentin Clement   }
261b0eef1eeSValentin Clement };
262b0eef1eeSValentin Clement } // end anonymous namespace
263b0eef1eeSValentin Clement } // namespace fir
264b0eef1eeSValentin Clement 
createAbstractResultOptPass()265b0eef1eeSValentin Clement std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
266b0eef1eeSValentin Clement   return std::make_unique<AbstractResultOpt>();
267b0eef1eeSValentin Clement }
268