1 //===-- BoxedProcedure.cpp ------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "flang/Optimizer/Builder/FIRBuilder.h"
11 #include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
12 #include "flang/Optimizer/CodeGen/CodeGen.h"
13 #include "flang/Optimizer/Dialect/FIRDialect.h"
14 #include "flang/Optimizer/Dialect/FIROps.h"
15 #include "flang/Optimizer/Dialect/FIRType.h"
16 #include "flang/Optimizer/Support/FIRContext.h"
17 #include "flang/Optimizer/Support/FatalError.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Pass/Pass.h"
20 #include "mlir/Transforms/DialectConversion.h"
21
22 #define DEBUG_TYPE "flang-procedure-pointer"
23
24 using namespace fir;
25
26 namespace {
/// Configuration knobs for the boxed procedure pass.
struct BoxedProcedureOptions {
  /// When set, the boxproc abstraction is lowered to function pointers, with
  /// thunks created where a host frame must be captured.
  bool useThunks{true};
};
33
34 /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
35 class BoxprocTypeRewriter : public mlir::TypeConverter {
36 public:
37 using mlir::TypeConverter::convertType;
38
39 /// Does the type \p ty need to be converted?
40 /// Any type that is a `!fir.boxproc` in whole or in part will need to be
41 /// converted to a function type to lower the IR to function pointer form in
42 /// the default implementation performed in this pass. Other implementations
43 /// are possible, so those may convert `!fir.boxproc` to some other type or
44 /// not at all depending on the implementation target's characteristics and
45 /// preference.
needsConversion(mlir::Type ty)46 bool needsConversion(mlir::Type ty) {
47 if (ty.isa<BoxProcType>())
48 return true;
49 if (auto funcTy = ty.dyn_cast<mlir::FunctionType>()) {
50 for (auto t : funcTy.getInputs())
51 if (needsConversion(t))
52 return true;
53 for (auto t : funcTy.getResults())
54 if (needsConversion(t))
55 return true;
56 return false;
57 }
58 if (auto tupleTy = ty.dyn_cast<mlir::TupleType>()) {
59 for (auto t : tupleTy.getTypes())
60 if (needsConversion(t))
61 return true;
62 return false;
63 }
64 if (auto recTy = ty.dyn_cast<RecordType>()) {
65 if (llvm::any_of(visitedTypes,
66 [&](mlir::Type rt) { return rt == recTy; }))
67 return false;
68 bool result = false;
69 visitedTypes.push_back(recTy);
70 for (auto t : recTy.getTypeList()) {
71 if (needsConversion(t.second)) {
72 result = true;
73 break;
74 }
75 }
76 visitedTypes.pop_back();
77 return result;
78 }
79 if (auto boxTy = ty.dyn_cast<BoxType>())
80 return needsConversion(boxTy.getEleTy());
81 if (isa_ref_type(ty))
82 return needsConversion(unwrapRefType(ty));
83 if (auto t = ty.dyn_cast<SequenceType>())
84 return needsConversion(unwrapSequenceType(ty));
85 return false;
86 }
87
BoxprocTypeRewriter(mlir::Location location)88 BoxprocTypeRewriter(mlir::Location location) : loc{location} {
__anonfdb31fde0302(mlir::Type ty) 89 addConversion([](mlir::Type ty) { return ty; });
90 addConversion(
__anonfdb31fde0402(BoxProcType boxproc) 91 [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
__anonfdb31fde0502(mlir::TupleType tupTy) 92 addConversion([&](mlir::TupleType tupTy) {
93 llvm::SmallVector<mlir::Type> memTys;
94 for (auto ty : tupTy.getTypes())
95 memTys.push_back(convertType(ty));
96 return mlir::TupleType::get(tupTy.getContext(), memTys);
97 });
__anonfdb31fde0602(mlir::FunctionType funcTy) 98 addConversion([&](mlir::FunctionType funcTy) {
99 llvm::SmallVector<mlir::Type> inTys;
100 llvm::SmallVector<mlir::Type> resTys;
101 for (auto ty : funcTy.getInputs())
102 inTys.push_back(convertType(ty));
103 for (auto ty : funcTy.getResults())
104 resTys.push_back(convertType(ty));
105 return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
106 });
__anonfdb31fde0702(ReferenceType ty) 107 addConversion([&](ReferenceType ty) {
108 return ReferenceType::get(convertType(ty.getEleTy()));
109 });
__anonfdb31fde0802(PointerType ty) 110 addConversion([&](PointerType ty) {
111 return PointerType::get(convertType(ty.getEleTy()));
112 });
113 addConversion(
__anonfdb31fde0902(HeapType ty) 114 [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
115 addConversion(
__anonfdb31fde0a02(BoxType ty) 116 [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
__anonfdb31fde0b02(SequenceType ty) 117 addConversion([&](SequenceType ty) {
118 // TODO: add ty.getLayoutMap() as needed.
119 return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
120 });
__anonfdb31fde0c02(RecordType ty) 121 addConversion([&](RecordType ty) -> mlir::Type {
122 if (!needsConversion(ty))
123 return ty;
124 // FIR record types can have recursive references, so conversion is a bit
125 // more complex than the other types. This conversion is not needed
126 // presently, so just emit a TODO message. Need to consider the uniqued
127 // name of the record, etc. Also, fir::RecordType::get returns the
128 // existing type being translated. So finalize() will not change it, and
129 // the translation would not do anything. So the type needs to be mutated,
130 // and this might require special care to comply with MLIR infrastructure.
131
132 // TODO: this will be needed to support derived type containing procedure
133 // pointer components.
134 fir::emitFatalError(
135 loc, "not yet implemented: record type with a boxproc type");
136 return RecordType::get(ty.getContext(), "*fixme*");
137 });
138 addArgumentMaterialization(materializeProcedure);
139 addSourceMaterialization(materializeProcedure);
140 addTargetMaterialization(materializeProcedure);
141 }
142
materializeProcedure(mlir::OpBuilder & builder,BoxProcType type,mlir::ValueRange inputs,mlir::Location loc)143 static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
144 BoxProcType type,
145 mlir::ValueRange inputs,
146 mlir::Location loc) {
147 assert(inputs.size() == 1);
148 return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()),
149 inputs[0]);
150 }
151
setLocation(mlir::Location location)152 void setLocation(mlir::Location location) { loc = location; }
153
154 private:
155 llvm::SmallVector<mlir::Type> visitedTypes;
156 mlir::Location loc;
157 };
158
159 /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
160 /// Fortran procedures can be referenced directly through a function pointer.
161 /// However, Fortran has one-level dynamic scoping between a host procedure and
162 /// its internal procedures. This allows internal procedures to directly access
163 /// and modify the state of the host procedure's variables.
164 ///
165 /// There are any number of possible implementations possible.
166 ///
167 /// The implementation used here is to convert `boxproc` values to function
168 /// pointers everywhere. If a `boxproc` value includes a frame pointer to the
169 /// host procedure's data, then a thunk will be created at runtime to capture
170 /// the frame pointer during execution. In LLVM IR, the frame pointer is
171 /// designated with the `nest` attribute. The thunk's address will then be used
172 /// as the call target instead of the original function's address directly.
173 class BoxedProcedurePass : public BoxedProcedurePassBase<BoxedProcedurePass> {
174 public:
BoxedProcedurePass()175 BoxedProcedurePass() { options = {true}; }
BoxedProcedurePass(bool useThunks)176 BoxedProcedurePass(bool useThunks) { options = {useThunks}; }
177
getModule()178 inline mlir::ModuleOp getModule() { return getOperation(); }
179
runOnOperation()180 void runOnOperation() override final {
181 if (options.useThunks) {
182 auto *context = &getContext();
183 mlir::IRRewriter rewriter(context);
184 BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));
185 mlir::Dialect *firDialect = context->getLoadedDialect("fir");
186 getModule().walk([&](mlir::Operation *op) {
187 typeConverter.setLocation(op->getLoc());
188 if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
189 auto ty = addr.getVal().getType();
190 if (typeConverter.needsConversion(ty) ||
191 ty.isa<mlir::FunctionType>()) {
192 // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
193 // or function type to be `fir.convert` ops.
194 rewriter.setInsertionPoint(addr);
195 rewriter.replaceOpWithNewOp<ConvertOp>(
196 addr, typeConverter.convertType(addr.getType()), addr.getVal());
197 }
198 } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
199 mlir::FunctionType ty = func.getFunctionType();
200 if (typeConverter.needsConversion(ty)) {
201 rewriter.startRootUpdate(func);
202 auto toTy =
203 typeConverter.convertType(ty).cast<mlir::FunctionType>();
204 if (!func.empty())
205 for (auto e : llvm::enumerate(toTy.getInputs())) {
206 unsigned i = e.index();
207 auto &block = func.front();
208 block.insertArgument(i, e.value(), func.getLoc());
209 block.getArgument(i + 1).replaceAllUsesWith(
210 block.getArgument(i));
211 block.eraseArgument(i + 1);
212 }
213 func.setType(toTy);
214 rewriter.finalizeRootUpdate(func);
215 }
216 } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
217 // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
218 // as required.
219 mlir::Type toTy = embox.getType().cast<BoxProcType>().getEleTy();
220 rewriter.setInsertionPoint(embox);
221 if (embox.getHost()) {
222 // Create the thunk.
223 auto module = embox->getParentOfType<mlir::ModuleOp>();
224 fir::KindMapping kindMap = getKindMapping(module);
225 FirOpBuilder builder(rewriter, kindMap);
226 auto loc = embox.getLoc();
227 mlir::Type i8Ty = builder.getI8Type();
228 mlir::Type i8Ptr = builder.getRefType(i8Ty);
229 mlir::Type buffTy = SequenceType::get({32}, i8Ty);
230 auto buffer = builder.create<AllocaOp>(loc, buffTy);
231 mlir::Value closure =
232 builder.createConvert(loc, i8Ptr, embox.getHost());
233 mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer);
234 mlir::Value func =
235 builder.createConvert(loc, i8Ptr, embox.getFunc());
236 builder.create<fir::CallOp>(
237 loc, factory::getLlvmInitTrampoline(builder),
238 llvm::ArrayRef<mlir::Value>{tramp, func, closure});
239 auto adjustCall = builder.create<fir::CallOp>(
240 loc, factory::getLlvmAdjustTrampoline(builder),
241 llvm::ArrayRef<mlir::Value>{tramp});
242 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
243 adjustCall.getResult(0));
244 } else {
245 // Just forward the function as a pointer.
246 rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
247 embox.getFunc());
248 }
249 } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
250 auto ty = mem.getType();
251 if (typeConverter.needsConversion(ty)) {
252 rewriter.setInsertionPoint(mem);
253 auto toTy = typeConverter.convertType(unwrapRefType(ty));
254 bool isPinned = mem.getPinned();
255 llvm::StringRef uniqName =
256 mem.getUniqName().value_or(llvm::StringRef());
257 llvm::StringRef bindcName =
258 mem.getBindcName().value_or(llvm::StringRef());
259 rewriter.replaceOpWithNewOp<AllocaOp>(
260 mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(),
261 mem.getShape());
262 }
263 } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
264 auto ty = mem.getType();
265 if (typeConverter.needsConversion(ty)) {
266 rewriter.setInsertionPoint(mem);
267 auto toTy = typeConverter.convertType(unwrapRefType(ty));
268 llvm::StringRef uniqName =
269 mem.getUniqName().value_or(llvm::StringRef());
270 llvm::StringRef bindcName =
271 mem.getBindcName().value_or(llvm::StringRef());
272 rewriter.replaceOpWithNewOp<AllocMemOp>(
273 mem, toTy, uniqName, bindcName, mem.getTypeparams(),
274 mem.getShape());
275 }
276 } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
277 auto ty = coor.getType();
278 mlir::Type baseTy = coor.getBaseType();
279 if (typeConverter.needsConversion(ty) ||
280 typeConverter.needsConversion(baseTy)) {
281 rewriter.setInsertionPoint(coor);
282 auto toTy = typeConverter.convertType(ty);
283 auto toBaseTy = typeConverter.convertType(baseTy);
284 rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
285 coor.getCoor(), toBaseTy);
286 }
287 } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
288 auto ty = index.getType();
289 mlir::Type onTy = index.getOnType();
290 if (typeConverter.needsConversion(ty) ||
291 typeConverter.needsConversion(onTy)) {
292 rewriter.setInsertionPoint(index);
293 auto toTy = typeConverter.convertType(ty);
294 auto toOnTy = typeConverter.convertType(onTy);
295 rewriter.replaceOpWithNewOp<FieldIndexOp>(
296 index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
297 }
298 } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
299 auto ty = index.getType();
300 mlir::Type onTy = index.getOnType();
301 if (typeConverter.needsConversion(ty) ||
302 typeConverter.needsConversion(onTy)) {
303 rewriter.setInsertionPoint(index);
304 auto toTy = typeConverter.convertType(ty);
305 auto toOnTy = typeConverter.convertType(onTy);
306 rewriter.replaceOpWithNewOp<LenParamIndexOp>(
307 mem, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
308 }
309 } else if (op->getDialect() == firDialect) {
310 rewriter.startRootUpdate(op);
311 for (auto i : llvm::enumerate(op->getResultTypes()))
312 if (typeConverter.needsConversion(i.value())) {
313 auto toTy = typeConverter.convertType(i.value());
314 op->getResult(i.index()).setType(toTy);
315 }
316 rewriter.finalizeRootUpdate(op);
317 }
318 });
319 }
320 // TODO: any alternative implementation. Note: currently, the default code
321 // gen will not be able to handle boxproc and will give an error.
322 }
323
324 private:
325 BoxedProcedureOptions options;
326 };
327 } // namespace
328
createBoxedProcedurePass()329 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() {
330 return std::make_unique<BoxedProcedurePass>();
331 }
332
createBoxedProcedurePass(bool useThunks)333 std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) {
334 return std::make_unique<BoxedProcedurePass>(useThunks);
335 }
336