1 //===- AbstractResult.cpp - Conversion of Abstract Function Result --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "flang/Optimizer/Builder/Todo.h" 11 #include "flang/Optimizer/Dialect/FIRDialect.h" 12 #include "flang/Optimizer/Dialect/FIROps.h" 13 #include "flang/Optimizer/Dialect/FIRType.h" 14 #include "flang/Optimizer/Transforms/Passes.h" 15 #include "mlir/Dialect/Func/IR/FuncOps.h" 16 #include "mlir/IR/Diagnostics.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 #include "mlir/Transforms/Passes.h" 20 #include "llvm/ADT/TypeSwitch.h" 21 22 #define DEBUG_TYPE "flang-abstract-result-opt" 23 24 namespace fir { 25 namespace { 26 27 static mlir::Type getResultArgumentType(mlir::Type resultType, 28 bool shouldBoxResult) { 29 return llvm::TypeSwitch<mlir::Type, mlir::Type>(resultType) 30 .Case<fir::SequenceType, fir::RecordType>( 31 [&](mlir::Type type) -> mlir::Type { 32 if (shouldBoxResult) 33 return fir::BoxType::get(type); 34 return fir::ReferenceType::get(type); 35 }) 36 .Case<fir::BoxType>([](mlir::Type type) -> mlir::Type { 37 return fir::ReferenceType::get(type); 38 }) 39 .Default([](mlir::Type) -> mlir::Type { 40 llvm_unreachable("bad abstract result type"); 41 }); 42 } 43 44 static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, 45 bool shouldBoxResult) { 46 auto resultType = funcTy.getResult(0); 47 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 48 llvm::SmallVector<mlir::Type> newInputTypes = {argTy}; 49 newInputTypes.append(funcTy.getInputs().begin(), funcTy.getInputs().end()); 50 return mlir::FunctionType::get(funcTy.getContext(), newInputTypes, 51 /*resultTypes=*/{}); 52 } 53 54 static bool mustEmboxResult(mlir::Type resultType, bool shouldBoxResult) { 55 return resultType.isa<fir::SequenceType, fir::RecordType>() && 56 shouldBoxResult; 57 } 58 59 class CallOpConversion : public mlir::OpRewritePattern<fir::CallOp> { 60 public: 61 using OpRewritePattern::OpRewritePattern; 62 CallOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 63 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 64 mlir::LogicalResult 65 matchAndRewrite(fir::CallOp callOp, 66 mlir::PatternRewriter &rewriter) const override { 67 auto loc = callOp.getLoc(); 68 auto result = callOp->getResult(0); 69 if (!result.hasOneUse()) { 70 mlir::emitError(loc, 71 "calls with abstract result must have exactly one user"); 72 return mlir::failure(); 73 } 74 auto saveResult = 75 mlir::dyn_cast<fir::SaveResultOp>(result.use_begin().getUser()); 76 if (!saveResult) { 77 mlir::emitError( 78 loc, "calls with abstract result must be used in fir.save_result"); 79 return mlir::failure(); 80 } 81 auto argType = getResultArgumentType(result.getType(), shouldBoxResult); 82 auto buffer = saveResult.getMemref(); 83 mlir::Value arg = buffer; 84 if (mustEmboxResult(result.getType(), shouldBoxResult)) 85 arg = rewriter.create<fir::EmboxOp>( 86 loc, argType, buffer, saveResult.getShape(), /*slice*/ mlir::Value{}, 87 saveResult.getTypeparams()); 88 89 llvm::SmallVector<mlir::Type> newResultTypes; 90 if (callOp.getCallee()) { 91 llvm::SmallVector<mlir::Value> newOperands = {arg}; 92 newOperands.append(callOp.getOperands().begin(), 93 callOp.getOperands().end()); 94 rewriter.create<fir::CallOp>(loc, *callOp.getCallee(), newResultTypes, 95 newOperands); 96 } else { 97 // Indirect calls. 98 llvm::SmallVector<mlir::Type> newInputTypes = {argType}; 99 for (auto operand : callOp.getOperands().drop_front()) 100 newInputTypes.push_back(operand.getType()); 101 auto funTy = mlir::FunctionType::get(callOp.getContext(), newInputTypes, 102 newResultTypes); 103 104 llvm::SmallVector<mlir::Value> newOperands; 105 newOperands.push_back( 106 rewriter.create<fir::ConvertOp>(loc, funTy, callOp.getOperand(0))); 107 newOperands.push_back(arg); 108 newOperands.append(callOp.getOperands().begin() + 1, 109 callOp.getOperands().end()); 110 rewriter.create<fir::CallOp>(loc, mlir::SymbolRefAttr{}, newResultTypes, 111 newOperands); 112 } 113 callOp->dropAllReferences(); 114 rewriter.eraseOp(callOp); 115 return mlir::success(); 116 } 117 118 private: 119 bool shouldBoxResult; 120 }; 121 122 class SaveResultOpConversion 123 : public mlir::OpRewritePattern<fir::SaveResultOp> { 124 public: 125 using OpRewritePattern::OpRewritePattern; 126 SaveResultOpConversion(mlir::MLIRContext *context) 127 : OpRewritePattern(context) {} 128 mlir::LogicalResult 129 matchAndRewrite(fir::SaveResultOp op, 130 mlir::PatternRewriter &rewriter) const override { 131 rewriter.eraseOp(op); 132 return mlir::success(); 133 } 134 }; 135 136 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { 137 public: 138 using OpRewritePattern::OpRewritePattern; 139 ReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) 140 : OpRewritePattern(context), newArg{newArg} {} 141 mlir::LogicalResult 142 matchAndRewrite(mlir::func::ReturnOp ret, 143 mlir::PatternRewriter &rewriter) const override { 144 rewriter.setInsertionPoint(ret); 145 auto returnedValue = ret.getOperand(0); 146 bool replacedStorage = false; 147 if (auto *op = returnedValue.getDefiningOp()) 148 if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) { 149 auto resultStorage = load.getMemref(); 150 load.getMemref().replaceAllUsesWith(newArg); 151 replacedStorage = true; 152 if (auto *alloc = resultStorage.getDefiningOp()) 153 if (alloc->use_empty()) 154 rewriter.eraseOp(alloc); 155 } 156 // The result storage may have been optimized out by a memory to 157 // register pass, this is possible for fir.box results, or fir.record 158 // with no length parameters. Simply store the result in the result storage. 159 // at the return point. 160 if (!replacedStorage) 161 rewriter.create<fir::StoreOp>(ret.getLoc(), returnedValue, newArg); 162 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); 163 return mlir::success(); 164 } 165 166 private: 167 mlir::Value newArg; 168 }; 169 170 class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> { 171 public: 172 using OpRewritePattern::OpRewritePattern; 173 AddrOfOpConversion(mlir::MLIRContext *context, bool shouldBoxResult) 174 : OpRewritePattern(context), shouldBoxResult{shouldBoxResult} {} 175 mlir::LogicalResult 176 matchAndRewrite(fir::AddrOfOp addrOf, 177 mlir::PatternRewriter &rewriter) const override { 178 auto oldFuncTy = addrOf.getType().cast<mlir::FunctionType>(); 179 auto newFuncTy = getNewFunctionType(oldFuncTy, shouldBoxResult); 180 auto newAddrOf = rewriter.create<fir::AddrOfOp>(addrOf.getLoc(), newFuncTy, 181 addrOf.getSymbol()); 182 // Rather than converting all op a function pointer might transit through 183 // (e.g calls, stores, loads, converts...), cast new type to the abstract 184 // type. A conversion will be added when calling indirect calls of abstract 185 // types. 186 rewriter.replaceOpWithNewOp<fir::ConvertOp>(addrOf, oldFuncTy, newAddrOf); 187 return mlir::success(); 188 } 189 190 private: 191 bool shouldBoxResult; 192 }; 193 194 class AbstractResultOpt : public fir::AbstractResultOptBase<AbstractResultOpt> { 195 public: 196 void runOnOperation() override { 197 auto *context = &getContext(); 198 auto func = getOperation(); 199 auto loc = func.getLoc(); 200 mlir::RewritePatternSet patterns(context); 201 mlir::ConversionTarget target = *context; 202 const bool shouldBoxResult = passResultAsBox.getValue(); 203 204 // Convert function type itself if it has an abstract result 205 auto funcTy = func.getFunctionType().cast<mlir::FunctionType>(); 206 if (hasAbstractResult(funcTy)) { 207 func.setType(getNewFunctionType(funcTy, shouldBoxResult)); 208 unsigned zero = 0; 209 if (!func.empty()) { 210 // Insert new argument 211 mlir::OpBuilder rewriter(context); 212 auto resultType = funcTy.getResult(0); 213 auto argTy = getResultArgumentType(resultType, shouldBoxResult); 214 mlir::Value newArg = func.front().insertArgument(zero, argTy, loc); 215 if (mustEmboxResult(resultType, shouldBoxResult)) { 216 auto bufferType = fir::ReferenceType::get(resultType); 217 rewriter.setInsertionPointToStart(&func.front()); 218 newArg = rewriter.create<fir::BoxAddrOp>(loc, bufferType, newArg); 219 } 220 patterns.insert<ReturnOpConversion>(context, newArg); 221 target.addDynamicallyLegalOp<mlir::func::ReturnOp>( 222 [](mlir::func::ReturnOp ret) { return ret.operands().empty(); }); 223 } 224 } 225 226 if (func.empty()) 227 return; 228 229 // Convert the calls and, if needed, the ReturnOp in the function body. 230 target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithmeticDialect, 231 mlir::func::FuncDialect>(); 232 target.addIllegalOp<fir::SaveResultOp>(); 233 target.addDynamicallyLegalOp<fir::CallOp>([](fir::CallOp call) { 234 return !hasAbstractResult(call.getFunctionType()); 235 }); 236 target.addDynamicallyLegalOp<fir::AddrOfOp>([](fir::AddrOfOp addrOf) { 237 if (auto funTy = addrOf.getType().dyn_cast<mlir::FunctionType>()) 238 return !hasAbstractResult(funTy); 239 return true; 240 }); 241 target.addDynamicallyLegalOp<fir::DispatchOp>([](fir::DispatchOp dispatch) { 242 if (dispatch->getNumResults() != 1) 243 return true; 244 auto resultType = dispatch->getResult(0).getType(); 245 if (resultType.isa<fir::SequenceType, fir::BoxType, fir::RecordType>()) { 246 TODO(dispatch.getLoc(), "dispatchOp with abstract results"); 247 return false; 248 } 249 return true; 250 }); 251 252 patterns.insert<CallOpConversion>(context, shouldBoxResult); 253 patterns.insert<SaveResultOpConversion>(context); 254 patterns.insert<AddrOfOpConversion>(context, shouldBoxResult); 255 if (mlir::failed( 256 mlir::applyPartialConversion(func, target, std::move(patterns)))) { 257 mlir::emitError(func.getLoc(), "error in converting abstract results\n"); 258 signalPassFailure(); 259 } 260 } 261 }; 262 } // end anonymous namespace 263 } // namespace fir 264 265 std::unique_ptr<mlir::Pass> fir::createAbstractResultOptPass() { 266 return std::make_unique<AbstractResultOpt>(); 267 } 268