1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/Todo.h"
11 #include "flang/Optimizer/Dialect/FIRDialect.h"
12 #include "flang/Optimizer/Dialect/FIROps.h"
13 #include "flang/Optimizer/Dialect/FIRType.h"
14 #include "flang/Optimizer/Transforms/Passes.h"
15 #include "mlir/Dialect/Func/IR/FuncOps.h"
16 #include "mlir/IR/Diagnostics.h"
17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/DialectConversion.h"
19 #include "mlir/Transforms/Passes.h"
20 #include "llvm/ADT/TypeSwitch.h"
21
22 #define DEBUG_TYPE "flang-abstract-result-opt"
23
24 namespace fir {
25 namespace {
26
getResultArgumentType(mlir::Type resultType,bool shouldBoxResult)27 static mlir::Type getResultArgumentType(mlir::Type resultType,
28 bool shouldBoxResult) {
29 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType)
30 .Case<fir::SequenceType, fir::RecordType>(
31 [&](mlir::Type type) -> mlir::Type {
32 if (shouldBoxResult)
33 return fir::BoxType::get(type);
34 return fir::ReferenceType::get(type);
35 })
36 .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type {
37 return fir::ReferenceType::get(type);
38 })
39 .Default([](mlir::Type) -> mlir::Type {
40 llvm_unreachable("bad abstract result type");
41 });
42 }
43
getNewFunctionType(mlir::FunctionType funcTy,bool shouldBoxResult)44 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
45 bool shouldBoxResult) {
46 auto resultType = funcTy.getResult(0);
47 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
48 llvm::SmallVector<mlir::Type> newInputTypes = {argTy};
49 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end());
50 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes,
51 /*resultTypes=*/{});
52 }
53
mustEmboxResult(mlir::Type resultType,bool shouldBoxResult)54 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) {
55 return resultType.isa<fir::SequenceType, fir::RecordType>() &&
56 shouldBoxResult;
57 }
58
59 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> {
60 public:
61 using OpRewritePattern::OpRewritePattern;
CallOpConversion(mlir::MLIRContext * context,bool shouldBoxResult)62 CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
63 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
64 mlir::LogicalResult
matchAndRewrite(fir::CallOp callOp,mlir::PatternRewriter & rewriter) const65 matchAndRewrite(fir::CallOp callOp,
66 mlir::PatternRewriter &rewriter) const override {
67 auto loc = callOp.getLoc();
68 auto result = callOp->getResult(0);
69 if (!result.hasOneUse()) {
70 mlir::emitError(loc,
71 "calls with abstract result must have exactly one user");
72 return mlir::failure();
73 }
74 auto saveResult =
75 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser());
76 if (!saveResult) {
77 mlir::emitError(
78 loc, "calls with abstract result must be used in fir.save_result");
79 return mlir::failure();
80 }
81 auto argType = getResultArgumentType(result.getType(), shouldBoxResult);
82 auto buffer = saveResult.getMemref();
83 mlir::Value arg = buffer;
84 if (mustEmboxResult(result.getType(), shouldBoxResult))
85 arg = rewriter.create<fir::EmboxOp>(
86 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{},
87 saveResult.getTypeparams());
88
89 llvm::SmallVector<mlir::Type> newResultTypes;
90 if (callOp.getCallee()) {
91 llvm::SmallVector<mlir::Value> newOperands = {arg};
92 newOperands.append(callOp.getOperands().begin(),
93 callOp.getOperands().end());
94 rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes,
95 newOperands);
96 } else {
97 // Indirect calls.
98 llvm::SmallVector<mlir::Type> newInputTypes = {argType};
99 for (auto operand : callOp.getOperands().drop_front())
100 newInputTypes.push_back(operand.getType());
101 auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes,
102 newResultTypes);
103
104 llvm::SmallVector<mlir::Value> newOperands;
105 newOperands.push_back(
106 rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0)));
107 newOperands.push_back(arg);
108 newOperands.append(callOp.getOperands().begin() + 1,
109 callOp.getOperands().end());
110 rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes,
111 newOperands);
112 }
113 callOp->dropAllReferences();
114 rewriter.eraseOp(callOp);
115 return mlir::success();
116 }
117
118 private:
119 bool shouldBoxResult;
120 };
121
122 class SaveResultOpConversion
123 : public mlir::OpRewritePattern<fir::SaveResultOp> {
124 public:
125 using OpRewritePattern::OpRewritePattern;
SaveResultOpConversion(mlir::MLIRContext * context)126 SaveResultOpConversion(mlir::MLIRContext *context)
127 : OpRewritePattern(context) {}
128 mlir::LogicalResult
matchAndRewrite(fir::SaveResultOp op,mlir::PatternRewriter & rewriter) const129 matchAndRewrite(fir::SaveResultOp op,
130 mlir::PatternRewriter &rewriter) const override {
131 rewriter.eraseOp(op);
132 return mlir::success();
133 }
134 };
135
136 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
137 public:
138 using OpRewritePattern::OpRewritePattern;
ReturnOpConversion(mlir::MLIRContext * context,mlir::Value newArg)139 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
140 : OpRewritePattern(context), newArg{newArg} {}
141 mlir::LogicalResult
matchAndRewrite(mlir::func::ReturnOp ret,mlir::PatternRewriter & rewriter) const142 matchAndRewrite(mlir::func::ReturnOp ret,
143 mlir::PatternRewriter &rewriter) const override {
144 rewriter.setInsertionPoint(ret);
145 auto returnedValue = ret.getOperand(0);
146 bool replacedStorage = false;
147 if (auto *op = returnedValue.getDefiningOp())
148 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
149 auto resultStorage = load.getMemref();
150 load.getMemref().replaceAllUsesWith(newArg);
151 replacedStorage = true;
152 if (auto *alloc = resultStorage.getDefiningOp())
153 if (alloc->use_empty())
154 rewriter.eraseOp(alloc);
155 }
156 // The result storage may have been optimized out by a memory to
157 // register pass, this is possible for fir.box results, or fir.record
158 // with no length parameters. Simply store the result in the result storage.
159 // at the return point.
160 if (!replacedStorage)
161 rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg);
162 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
163 return mlir::success();
164 }
165
166 private:
167 mlir::Value newArg;
168 };
169
170 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
171 public:
172 using OpRewritePattern::OpRewritePattern;
AddrOfOpConversion(mlir::MLIRContext * context,bool shouldBoxResult)173 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult)
174 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {}
175 mlir::LogicalResult
matchAndRewrite(fir::AddrOfOp addrOf,mlir::PatternRewriter & rewriter) const176 matchAndRewrite(fir::AddrOfOp addrOf,
177 mlir::PatternRewriter &rewriter) const override {
178 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>();
179 auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult);
180 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy,
181 addrOf.getSymbol());
182 // Rather than converting all op a function pointer might transit through
183 // (e.g calls, stores, loads, converts...), cast new type to the abstract
184 // type. A conversion will be added when calling indirect calls of abstract
185 // types.
186 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf);
187 return mlir::success();
188 }
189
190 private:
191 bool shouldBoxResult;
192 };
193
194 class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> {
195 public:
runOnOperation()196 void runOnOperation() override {
197 auto *context = &getContext();
198 auto func = getOperation();
199 auto loc = func.getLoc();
200 mlir::RewritePatternSet patterns(context);
201 mlir::ConversionTarget target = *context;
202 const bool shouldBoxResult = passResultAsBox.getValue();
203
204 // Convert function type itself if it has an abstract result
205 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>();
206 if (hasAbstractResult(funcTy)) {
207 func.setType(getNewFunctionType(funcTy, shouldBoxResult));
208 unsigned zero = 0;
209 if (!func.empty()) {
210 // Insert new argument
211 mlir::OpBuilder rewriter(context);
212 auto resultType = funcTy.getResult(0);
213 auto argTy = getResultArgumentType(resultType, shouldBoxResult);
214 mlir::Value newArg = func.front().insertArgument(zero, argTy, loc);
215 if (mustEmboxResult(resultType, shouldBoxResult)) {
216 auto bufferType = fir::ReferenceType::get(resultType);
217 rewriter.setInsertionPointToStart(&func.front());
218 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg);
219 }
220 patterns.insert<ReturnOpConversion>(context, newArg);
221 target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
222 [](mlir::func::ReturnOp ret) { return ret.operands().empty(); });
223 }
224 }
225
226 if (func.empty())
227 return;
228
229 // Convert the calls and, if needed, the ReturnOp in the function body.
230 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect,
231 mlir::func::FuncDialect>();
232 target.addIllegalOp<fir::SaveResultOp>();
233 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) {
234 return !hasAbstractResult(call.getFunctionType());
235 });
236 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) {
237 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>())
238 return !hasAbstractResult(funTy);
239 return true;
240 });
241 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) {
242 if (dispatch->getNumResults() != 1)
243 return true;
244 auto resultType = dispatch->getResult(0).getType();
245 if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) {
246 TODO(dispatch.getLoc(), "dispatchOp with abstract results");
247 return false;
248 }
249 return true;
250 });
251
252 patterns.insert<CallOpConversion>(context, shouldBoxResult);
253 patterns.insert<SaveResultOpConversion>(context);
254 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult);
255 if (mlir::failed(
256 mlir::applyPartialConversion(func, target, std::move(patterns)))) {
257 mlir::emitError(func.getLoc(), "error in converting abstract results\n");
258 signalPassFailure();
259 }
260 }
261 };
262 } // end anonymous namespace
263 } // namespace fir
264
createAbstractResultOptPass()265 std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() {
266 return std::make_unique<AbstractResultOpt>();
267 }
268