1 //===-- TargetRewrite.cpp -------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target rewrite: rewriting of ops to make target-specific lowerings manifest. 10 // LLVM expects different lowering idioms to be used for distinct target 11 // triples. These distinctions are handled by this pass. 12 // 13 // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "Target.h" 19 #include "flang/Lower/Todo.h" 20 #include "flang/Optimizer/CodeGen/CodeGen.h" 21 #include "flang/Optimizer/Dialect/FIRDialect.h" 22 #include "flang/Optimizer/Dialect/FIROps.h" 23 #include "flang/Optimizer/Dialect/FIRType.h" 24 #include "flang/Optimizer/Support/FIRContext.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "llvm/ADT/STLExtras.h" 27 #include "llvm/ADT/TypeSwitch.h" 28 #include "llvm/Support/Debug.h" 29 30 using namespace fir; 31 32 #define DEBUG_TYPE "flang-target-rewrite" 33 34 namespace { 35 36 /// Fixups for updating a FuncOp's arguments and return values. 37 struct FixupTy { 38 enum class Codes { CharPair, Trailing }; 39 40 FixupTy(Codes code, std::size_t index, std::size_t second = 0) 41 : code{code}, index{index}, second{second} {} 42 FixupTy(Codes code, std::size_t index, 43 std::function<void(mlir::FuncOp)> &&finalizer) 44 : code{code}, index{index}, finalizer{finalizer} {} 45 FixupTy(Codes code, std::size_t index, std::size_t second, 46 std::function<void(mlir::FuncOp)> &&finalizer) 47 : code{code}, index{index}, second{second}, finalizer{finalizer} {} 48 49 Codes code; 50 std::size_t index; 51 std::size_t second{}; 52 llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{}; 53 }; // namespace 54 55 /// Target-specific rewriting of the FIR. This is a prerequisite pass to code 56 /// generation that traverses the FIR and modifies types and operations to a 57 /// form that is appropriate for the specific target. LLVM IR has specific 58 /// idioms that are used for distinct target processor and ABI combinations. 59 class TargetRewrite : public TargetRewriteBase<TargetRewrite> { 60 public: 61 TargetRewrite(const TargetRewriteOptions &options) { 62 noCharacterConversion = options.noCharacterConversion; 63 } 64 65 void runOnOperation() override final { 66 auto &context = getContext(); 67 mlir::OpBuilder rewriter(&context); 68 69 auto mod = getModule(); 70 if (!forcedTargetTriple.empty()) { 71 setTargetTriple(mod, forcedTargetTriple); 72 } 73 74 auto specifics = CodeGenSpecifics::get(getOperation().getContext(), 75 getTargetTriple(getOperation()), 76 getKindMapping(getOperation())); 77 setMembers(specifics.get(), &rewriter); 78 79 // Perform type conversion on signatures and call sites. 80 if (mlir::failed(convertTypes(mod))) { 81 mlir::emitError(mlir::UnknownLoc::get(&context), 82 "error in converting types to target abi"); 83 signalPassFailure(); 84 } 85 86 // Convert ops in target-specific patterns. 87 mod.walk([&](mlir::Operation *op) { 88 if (auto call = dyn_cast<fir::CallOp>(op)) { 89 if (!hasPortableSignature(call.getFunctionType())) 90 convertCallOp(call); 91 } else if (auto dispatch = dyn_cast<DispatchOp>(op)) { 92 if (!hasPortableSignature(dispatch.getFunctionType())) 93 convertCallOp(dispatch); 94 } 95 }); 96 97 clearMembers(); 98 } 99 100 mlir::ModuleOp getModule() { return getOperation(); } 101 102 // Convert fir.call and fir.dispatch Ops. 103 template <typename A> 104 void convertCallOp(A callOp) { 105 auto fnTy = callOp.getFunctionType(); 106 auto loc = callOp.getLoc(); 107 rewriter->setInsertionPoint(callOp); 108 llvm::SmallVector<mlir::Type> newResTys; 109 llvm::SmallVector<mlir::Type> newInTys; 110 llvm::SmallVector<mlir::Value> newOpers; 111 112 // If the call is indirect, the first argument must still be the function 113 // to call. 114 int dropFront = 0; 115 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 116 if (!callOp.callee().hasValue()) { 117 newInTys.push_back(fnTy.getInput(0)); 118 newOpers.push_back(callOp.getOperand(0)); 119 dropFront = 1; 120 } 121 } 122 123 // Determine the rewrite function, `wrap`, for the result value. 124 llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; 125 if (fnTy.getResults().size() == 1) { 126 mlir::Type ty = fnTy.getResult(0); 127 newResTys.push_back(ty); 128 } else if (fnTy.getResults().size() > 1) { 129 TODO(loc, "multiple results not supported yet"); 130 } 131 132 llvm::SmallVector<mlir::Type> trailingInTys; 133 llvm::SmallVector<mlir::Value> trailingOpers; 134 for (auto e : llvm::enumerate( 135 llvm::zip(fnTy.getInputs().drop_front(dropFront), 136 callOp.getOperands().drop_front(dropFront)))) { 137 mlir::Type ty = std::get<0>(e.value()); 138 mlir::Value oper = std::get<1>(e.value()); 139 unsigned index = e.index(); 140 llvm::TypeSwitch<mlir::Type>(ty) 141 .template Case<BoxCharType>([&](BoxCharType boxTy) { 142 bool sret; 143 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 144 sret = callOp.callee() && 145 functionArgIsSRet(index, 146 getModule().lookupSymbol<mlir::FuncOp>( 147 *callOp.callee())); 148 } else { 149 // TODO: dispatch case; how do we put arguments on a call? 150 // We cannot put both an sret and the dispatch object first. 151 sret = false; 152 TODO(loc, "dispatch + sret not supported yet"); 153 } 154 auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); 155 auto unbox = 156 rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]), 157 std::get<mlir::Type>(m[1]), oper); 158 // unboxed CHARACTER arguments 159 for (auto e : llvm::enumerate(m)) { 160 unsigned idx = e.index(); 161 auto attr = std::get<CodeGenSpecifics::Attributes>(e.value()); 162 auto argTy = std::get<mlir::Type>(e.value()); 163 if (attr.isAppend()) { 164 trailingInTys.push_back(argTy); 165 trailingOpers.push_back(unbox.getResult(idx)); 166 } else { 167 newInTys.push_back(argTy); 168 newOpers.push_back(unbox.getResult(idx)); 169 } 170 } 171 }) 172 .Default([&](mlir::Type ty) { 173 newInTys.push_back(ty); 174 newOpers.push_back(oper); 175 }); 176 } 177 newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); 178 newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); 179 if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { 180 fir::CallOp newCall; 181 if (callOp.callee().hasValue()) { 182 newCall = rewriter->create<A>(loc, callOp.callee().getValue(), 183 newResTys, newOpers); 184 } else { 185 // Force new type on the input operand. 186 newOpers[0].setType(mlir::FunctionType::get( 187 callOp.getContext(), 188 mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); 189 newCall = rewriter->create<A>(loc, newResTys, newOpers); 190 } 191 LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); 192 if (wrap.hasValue()) 193 replaceOp(callOp, (*wrap)(newCall.getOperation())); 194 else 195 replaceOp(callOp, newCall.getResults()); 196 } else { 197 // A is fir::DispatchOp 198 TODO(loc, "dispatch not implemented"); 199 } 200 } 201 /// Convert the type signatures on all the functions present in the module. 202 /// As the type signature is being changed, this must also update the 203 /// function itself to use any new arguments, etc. 204 mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { 205 for (auto fn : mod.getOps<mlir::FuncOp>()) 206 convertSignature(fn); 207 return mlir::success(); 208 } 209 210 /// If the signature does not need any special target-specific converions, 211 /// then it is considered portable for any target, and this function will 212 /// return `true`. Otherwise, the signature is not portable and `false` is 213 /// returned. 214 bool hasPortableSignature(mlir::Type signature) { 215 assert(signature.isa<mlir::FunctionType>()); 216 auto func = signature.dyn_cast<mlir::FunctionType>(); 217 for (auto ty : func.getResults()) 218 if ((ty.isa<BoxCharType>() && !noCharacterConversion)) { 219 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 220 return false; 221 } 222 for (auto ty : func.getInputs()) 223 if ((ty.isa<BoxCharType>() && !noCharacterConversion)) { 224 LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); 225 return false; 226 } 227 return true; 228 } 229 230 /// Rewrite the signatures and body of the `FuncOp`s in the module for 231 /// the immediately subsequent target code gen. 232 void convertSignature(mlir::FuncOp func) { 233 auto funcTy = func.getType().cast<mlir::FunctionType>(); 234 if (hasPortableSignature(funcTy)) 235 return; 236 llvm::SmallVector<mlir::Type> newResTys; 237 llvm::SmallVector<mlir::Type> newInTys; 238 llvm::SmallVector<FixupTy> fixups; 239 240 // Convert return value(s) 241 for (auto ty : funcTy.getResults()) 242 newResTys.push_back(ty); 243 244 // Convert arguments 245 llvm::SmallVector<mlir::Type> trailingTys; 246 for (auto e : llvm::enumerate(funcTy.getInputs())) { 247 auto ty = e.value(); 248 unsigned index = e.index(); 249 llvm::TypeSwitch<mlir::Type>(ty) 250 .Case<BoxCharType>([&](BoxCharType boxTy) { 251 if (noCharacterConversion) { 252 newInTys.push_back(boxTy); 253 } else { 254 // Convert a CHARACTER argument type. This can involve separating 255 // the pointer and the LEN into two arguments and moving the LEN 256 // argument to the end of the arg list. 257 bool sret = functionArgIsSRet(index, func); 258 for (auto e : llvm::enumerate(specifics->boxcharArgumentType( 259 boxTy.getEleTy(), sret))) { 260 auto &tup = e.value(); 261 auto index = e.index(); 262 auto attr = std::get<CodeGenSpecifics::Attributes>(tup); 263 auto argTy = std::get<mlir::Type>(tup); 264 if (attr.isAppend()) { 265 trailingTys.push_back(argTy); 266 } else { 267 if (sret) { 268 fixups.emplace_back(FixupTy::Codes::CharPair, 269 newInTys.size(), index); 270 } else { 271 fixups.emplace_back(FixupTy::Codes::Trailing, 272 newInTys.size(), trailingTys.size()); 273 } 274 newInTys.push_back(argTy); 275 } 276 } 277 } 278 }) 279 .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); 280 } 281 282 if (!func.empty()) { 283 // If the function has a body, then apply the fixups to the arguments and 284 // return ops as required. These fixups are done in place. 285 auto loc = func.getLoc(); 286 const auto fixupSize = fixups.size(); 287 const auto oldArgTys = func.getType().getInputs(); 288 int offset = 0; 289 for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { 290 const auto &fixup = fixups[i]; 291 switch (fixup.code) { 292 case FixupTy::Codes::CharPair: { 293 // The FIR boxchar argument has been split into a pair of distinct 294 // arguments that are in juxtaposition to each other. 295 auto newArg = 296 func.front().insertArgument(fixup.index, newInTys[fixup.index]); 297 if (fixup.second == 1) { 298 rewriter->setInsertionPointToStart(&func.front()); 299 auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; 300 auto box = rewriter->create<EmboxCharOp>( 301 loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); 302 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 303 func.front().eraseArgument(fixup.index + 1); 304 offset++; 305 } 306 } break; 307 case FixupTy::Codes::Trailing: { 308 // The FIR argument has been split into a pair of distinct arguments. 309 // The first part of the pair appears in the original argument 310 // position. The second part of the pair is appended after all the 311 // original arguments. (Boxchar arguments.) 312 auto newBufArg = 313 func.front().insertArgument(fixup.index, newInTys[fixup.index]); 314 auto newLenArg = func.front().addArgument(trailingTys[fixup.second]); 315 auto boxTy = oldArgTys[fixup.index - offset]; 316 rewriter->setInsertionPointToStart(&func.front()); 317 auto box = 318 rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg); 319 func.getArgument(fixup.index + 1).replaceAllUsesWith(box); 320 func.front().eraseArgument(fixup.index + 1); 321 } break; 322 } 323 } 324 } 325 326 // Set the new type and finalize the arguments, etc. 327 newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); 328 auto newFuncTy = 329 mlir::FunctionType::get(func.getContext(), newInTys, newResTys); 330 LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); 331 func.setType(newFuncTy); 332 333 for (auto &fixup : fixups) 334 if (fixup.finalizer) 335 (*fixup.finalizer)(func); 336 } 337 338 inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { 339 if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) 340 return true; 341 return false; 342 } 343 344 private: 345 // Replace `op` and remove it. 346 void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { 347 op->replaceAllUsesWith(newValues); 348 op->dropAllReferences(); 349 op->erase(); 350 } 351 352 inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { 353 specifics = s; 354 rewriter = r; 355 } 356 357 inline void clearMembers() { setMembers(nullptr, nullptr); } 358 359 CodeGenSpecifics *specifics{}; 360 mlir::OpBuilder *rewriter; 361 }; // namespace 362 } // namespace 363 364 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 365 fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { 366 return std::make_unique<TargetRewrite>(options); 367 } 368