1 //===- NormalizeMemRefs.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an interprocedural pass to normalize memrefs to have 10 // identity layout maps. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Affine/Utils.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 20 #include "llvm/ADT/SmallSet.h" 21 #include "llvm/Support/Debug.h" 22 23 #define DEBUG_TYPE "normalize-memrefs" 24 25 using namespace mlir; 26 27 namespace { 28 29 /// All memrefs passed across functions with non-trivial layout maps are 30 /// converted to ones with trivial identity layout ones. 31 /// If all the memref types/uses in a function are normalizable, we treat 32 /// such functions as normalizable. Also, if a normalizable function is known 33 /// to call a non-normalizable function, we treat that function as 34 /// non-normalizable as well. We assume external functions to be normalizable. 35 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> { 36 void runOnOperation() override; 37 void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp); 38 bool areMemRefsNormalizable(func::FuncOp funcOp); 39 void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp); 40 void setCalleesAndCallersNonNormalizable( 41 func::FuncOp funcOp, ModuleOp moduleOp, 42 DenseSet<func::FuncOp> &normalizableFuncs); 43 Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp); 44 }; 45 46 } // namespace 47 48 std::unique_ptr<OperationPass<ModuleOp>> 49 mlir::memref::createNormalizeMemRefsPass() { 50 return std::make_unique<NormalizeMemRefs>(); 51 } 52 53 void NormalizeMemRefs::runOnOperation() { 54 LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n"); 55 ModuleOp moduleOp = getOperation(); 56 // We maintain all normalizable FuncOps in a DenseSet. It is initialized 57 // with all the functions within a module and then functions which are not 58 // normalizable are removed from this set. 59 // TODO: Change this to work on FuncLikeOp once there is an operation 60 // interface for it. 61 DenseSet<func::FuncOp> normalizableFuncs; 62 // Initialize `normalizableFuncs` with all the functions within a module. 63 moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); }); 64 65 // Traverse through all the functions applying a filter which determines 66 // whether that function is normalizable or not. All callers/callees of 67 // a non-normalizable function will also become non-normalizable even if 68 // they aren't passing any or specific non-normalizable memrefs. So, 69 // functions which calls or get called by a non-normalizable becomes non- 70 // normalizable functions themselves. 71 moduleOp.walk([&](func::FuncOp funcOp) { 72 if (normalizableFuncs.contains(funcOp)) { 73 if (!areMemRefsNormalizable(funcOp)) { 74 LLVM_DEBUG(llvm::dbgs() 75 << "@" << funcOp.getName() 76 << " contains ops that cannot normalize MemRefs\n"); 77 // Since this function is not normalizable, we set all the caller 78 // functions and the callees of this function as not normalizable. 79 // TODO: Drop this conservative assumption in the future. 80 setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 81 normalizableFuncs); 82 } 83 } 84 }); 85 86 LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size() 87 << " functions\n"); 88 // Those functions which can be normalized are subjected to normalization. 89 for (func::FuncOp &funcOp : normalizableFuncs) 90 normalizeFuncOpMemRefs(funcOp, moduleOp); 91 } 92 93 /// Check whether all the uses of oldMemRef are either dereferencing uses or the 94 /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints 95 /// are satisfied will the value become a candidate for replacement. 96 /// TODO: Extend this for DimOps. 97 static bool isMemRefNormalizable(Value::user_range opUsers) { 98 return llvm::all_of(opUsers, [](Operation *op) { 99 return op->hasTrait<OpTrait::MemRefsNormalizable>(); 100 }); 101 } 102 103 /// Set all the calling functions and the callees of the function as not 104 /// normalizable. 105 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable( 106 func::FuncOp funcOp, ModuleOp moduleOp, 107 DenseSet<func::FuncOp> &normalizableFuncs) { 108 if (!normalizableFuncs.contains(funcOp)) 109 return; 110 111 LLVM_DEBUG( 112 llvm::dbgs() << "@" << funcOp.getName() 113 << " calls or is called by non-normalizable function\n"); 114 normalizableFuncs.erase(funcOp); 115 // Caller of the function. 116 Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp); 117 for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 118 // TODO: Extend this for ops that are FunctionOpInterface. This would 119 // require creating an OpInterface for FunctionOpInterface ops. 120 func::FuncOp parentFuncOp = 121 symbolUse.getUser()->getParentOfType<func::FuncOp>(); 122 for (func::FuncOp &funcOp : normalizableFuncs) { 123 if (parentFuncOp == funcOp) { 124 setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 125 normalizableFuncs); 126 break; 127 } 128 } 129 } 130 131 // Functions called by this function. 132 funcOp.walk([&](func::CallOp callOp) { 133 StringAttr callee = callOp.getCalleeAttr().getAttr(); 134 for (func::FuncOp &funcOp : normalizableFuncs) { 135 // We compare func::FuncOp and callee's name. 136 if (callee == funcOp.getNameAttr()) { 137 setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 138 normalizableFuncs); 139 break; 140 } 141 } 142 }); 143 } 144 145 /// Check whether all the uses of AllocOps, CallOps and function arguments of a 146 /// function are either of dereferencing type or are uses in: DeallocOp, CallOp 147 /// or ReturnOp. Only if these constraints are satisfied will the function 148 /// become a candidate for normalization. We follow a conservative approach here 149 /// wherein even if the non-normalizable memref is not a part of the function's 150 /// argument or return type, we still label the entire function as 151 /// non-normalizable. We assume external functions to be normalizable. 152 bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { 153 // We assume external functions to be normalizable. 154 if (funcOp.isExternal()) 155 return true; 156 157 if (funcOp 158 .walk([&](memref::AllocOp allocOp) -> WalkResult { 159 Value oldMemRef = allocOp.getResult(); 160 if (!isMemRefNormalizable(oldMemRef.getUsers())) 161 return WalkResult::interrupt(); 162 return WalkResult::advance(); 163 }) 164 .wasInterrupted()) 165 return false; 166 167 if (funcOp 168 .walk([&](func::CallOp callOp) -> WalkResult { 169 for (unsigned resIndex : 170 llvm::seq<unsigned>(0, callOp.getNumResults())) { 171 Value oldMemRef = callOp.getResult(resIndex); 172 if (oldMemRef.getType().isa<MemRefType>()) 173 if (!isMemRefNormalizable(oldMemRef.getUsers())) 174 return WalkResult::interrupt(); 175 } 176 return WalkResult::advance(); 177 }) 178 .wasInterrupted()) 179 return false; 180 181 for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) { 182 BlockArgument oldMemRef = funcOp.getArgument(argIndex); 183 if (oldMemRef.getType().isa<MemRefType>()) 184 if (!isMemRefNormalizable(oldMemRef.getUsers())) 185 return false; 186 } 187 188 return true; 189 } 190 191 /// Fetch the updated argument list and result of the function and update the 192 /// function signature. This updates the function's return type at the caller 193 /// site and in case the return type is a normalized memref then it updates 194 /// the calling function's signature. 195 /// TODO: An update to the calling function signature is required only if the 196 /// returned value is in turn used in ReturnOp of the calling function. 197 void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, 198 ModuleOp moduleOp) { 199 FunctionType functionType = funcOp.getFunctionType(); 200 SmallVector<Type, 4> resultTypes; 201 FunctionType newFuncType; 202 resultTypes = llvm::to_vector<4>(functionType.getResults()); 203 204 // External function's signature was already updated in 205 // 'normalizeFuncOpMemRefs()'. 206 if (!funcOp.isExternal()) { 207 SmallVector<Type, 8> argTypes; 208 for (const auto &argEn : llvm::enumerate(funcOp.getArguments())) 209 argTypes.push_back(argEn.value().getType()); 210 211 // Traverse ReturnOps to check if an update to the return type in the 212 // function signature is required. 213 funcOp.walk([&](func::ReturnOp returnOp) { 214 for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) { 215 Type opType = operandEn.value().getType(); 216 MemRefType memrefType = opType.dyn_cast<MemRefType>(); 217 // If type is not memref or if the memref type is same as that in 218 // function's return signature then no update is required. 219 if (!memrefType || memrefType == resultTypes[operandEn.index()]) 220 continue; 221 // Update function's return type signature. 222 // Return type gets normalized either as a result of function argument 223 // normalization, AllocOp normalization or an update made at CallOp. 224 // There can be many call flows inside a function and an update to a 225 // specific ReturnOp has not yet been made. So we check that the result 226 // memref type is normalized. 227 // TODO: When selective normalization is implemented, handle multiple 228 // results case where some are normalized, some aren't. 229 if (memrefType.getLayout().isIdentity()) 230 resultTypes[operandEn.index()] = memrefType; 231 } 232 }); 233 234 // We create a new function type and modify the function signature with this 235 // new type. 236 newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes, 237 /*results=*/resultTypes); 238 } 239 240 // Since we update the function signature, it might affect the result types at 241 // the caller site. Since this result might even be used by the caller 242 // function in ReturnOps, the caller function's signature will also change. 243 // Hence we record the caller function in 'funcOpsToUpdate' to update their 244 // signature as well. 245 llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate; 246 // We iterate over all symbolic uses of the function and update the return 247 // type at the caller site. 248 Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp); 249 for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 250 Operation *userOp = symbolUse.getUser(); 251 OpBuilder builder(userOp); 252 // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes 253 // that the non-CallOp has no memrefs to be replaced. 254 // TODO: Handle cases where a non-CallOp symbol use of a function deals with 255 // memrefs. 256 auto callOp = dyn_cast<func::CallOp>(userOp); 257 if (!callOp) 258 continue; 259 Operation *newCallOp = 260 builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(), 261 resultTypes, userOp->getOperands()); 262 bool replacingMemRefUsesFailed = false; 263 bool returnTypeChanged = false; 264 for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) { 265 OpResult oldResult = userOp->getResult(resIndex); 266 OpResult newResult = newCallOp->getResult(resIndex); 267 // This condition ensures that if the result is not of type memref or if 268 // the resulting memref was already having a trivial map layout then we 269 // need not perform any use replacement here. 270 if (oldResult.getType() == newResult.getType()) 271 continue; 272 AffineMap layoutMap = 273 oldResult.getType().cast<MemRefType>().getLayout().getAffineMap(); 274 if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult, 275 /*extraIndices=*/{}, 276 /*indexRemap=*/layoutMap, 277 /*extraOperands=*/{}, 278 /*symbolOperands=*/{}, 279 /*domOpFilter=*/nullptr, 280 /*postDomOpFilter=*/nullptr, 281 /*allowNonDereferencingOps=*/true, 282 /*replaceInDeallocOp=*/true))) { 283 // If it failed (due to escapes for example), bail out. 284 // It should never hit this part of the code because it is called by 285 // only those functions which are normalizable. 286 newCallOp->erase(); 287 replacingMemRefUsesFailed = true; 288 break; 289 } 290 returnTypeChanged = true; 291 } 292 if (replacingMemRefUsesFailed) 293 continue; 294 // Replace all uses for other non-memref result types. 295 userOp->replaceAllUsesWith(newCallOp); 296 userOp->erase(); 297 if (returnTypeChanged) { 298 // Since the return type changed it might lead to a change in function's 299 // signature. 300 // TODO: If funcOp doesn't return any memref type then no need to update 301 // signature. 302 // TODO: Further optimization - Check if the memref is indeed part of 303 // ReturnOp at the parentFuncOp and only then updation of signature is 304 // required. 305 // TODO: Extend this for ops that are FunctionOpInterface. This would 306 // require creating an OpInterface for FunctionOpInterface ops. 307 func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>(); 308 funcOpsToUpdate.insert(parentFuncOp); 309 } 310 } 311 // Because external function's signature is already updated in 312 // 'normalizeFuncOpMemRefs()', we don't need to update it here again. 313 if (!funcOp.isExternal()) 314 funcOp.setType(newFuncType); 315 316 // Updating the signature type of those functions which call the current 317 // function. Only if the return type of the current function has a normalized 318 // memref will the caller function become a candidate for signature update. 319 for (func::FuncOp parentFuncOp : funcOpsToUpdate) 320 updateFunctionSignature(parentFuncOp, moduleOp); 321 } 322 323 /// Normalizes the memrefs within a function which includes those arising as a 324 /// result of AllocOps, CallOps and function's argument. The ModuleOp argument 325 /// is used to help update function's signature after normalization. 326 void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, 327 ModuleOp moduleOp) { 328 // Turn memrefs' non-identity layouts maps into ones with identity. Collect 329 // alloc ops first and then process since normalizeMemRef replaces/erases ops 330 // during memref rewriting. 331 SmallVector<memref::AllocOp, 4> allocOps; 332 funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); }); 333 for (memref::AllocOp allocOp : allocOps) 334 (void)normalizeMemRef(&allocOp); 335 336 // We use this OpBuilder to create new memref layout later. 337 OpBuilder b(funcOp); 338 339 FunctionType functionType = funcOp.getFunctionType(); 340 SmallVector<Location> functionArgLocs(llvm::map_range( 341 funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); })); 342 SmallVector<Type, 8> inputTypes; 343 // Walk over each argument of a function to perform memref normalization (if 344 for (unsigned argIndex : 345 llvm::seq<unsigned>(0, functionType.getNumInputs())) { 346 Type argType = functionType.getInput(argIndex); 347 MemRefType memrefType = argType.dyn_cast<MemRefType>(); 348 // Check whether argument is of MemRef type. Any other argument type can 349 // simply be part of the final function signature. 350 if (!memrefType) { 351 inputTypes.push_back(argType); 352 continue; 353 } 354 // Fetch a new memref type after normalizing the old memref to have an 355 // identity map layout. 356 MemRefType newMemRefType = normalizeMemRefType(memrefType, b, 357 /*numSymbolicOperands=*/0); 358 if (newMemRefType == memrefType || funcOp.isExternal()) { 359 // Either memrefType already had an identity map or the map couldn't be 360 // transformed to an identity map. 361 inputTypes.push_back(newMemRefType); 362 continue; 363 } 364 365 // Insert a new temporary argument with the new memref type. 366 BlockArgument newMemRef = funcOp.front().insertArgument( 367 argIndex, newMemRefType, functionArgLocs[argIndex]); 368 BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1); 369 AffineMap layoutMap = memrefType.getLayout().getAffineMap(); 370 // Replace all uses of the old memref. 371 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, 372 /*extraIndices=*/{}, 373 /*indexRemap=*/layoutMap, 374 /*extraOperands=*/{}, 375 /*symbolOperands=*/{}, 376 /*domOpFilter=*/nullptr, 377 /*postDomOpFilter=*/nullptr, 378 /*allowNonDereferencingOps=*/true, 379 /*replaceInDeallocOp=*/true))) { 380 // If it failed (due to escapes for example), bail out. Removing the 381 // temporary argument inserted previously. 382 funcOp.front().eraseArgument(argIndex); 383 continue; 384 } 385 386 // All uses for the argument with old memref type were replaced 387 // successfully. So we remove the old argument now. 388 funcOp.front().eraseArgument(argIndex + 1); 389 } 390 391 // Walk over normalizable operations to normalize memrefs of the operation 392 // results. When `op` has memrefs with affine map in the operation results, 393 // new operation containin normalized memrefs is created. Then, the memrefs 394 // are replaced. `CallOp` is skipped here because it is handled in 395 // `updateFunctionSignature()`. 396 funcOp.walk([&](Operation *op) { 397 if (op->hasTrait<OpTrait::MemRefsNormalizable>() && 398 op->getNumResults() > 0 && !isa<func::CallOp>(op) && 399 !funcOp.isExternal()) { 400 // Create newOp containing normalized memref in the operation result. 401 Operation *newOp = createOpResultsNormalized(funcOp, op); 402 // When all of the operation results have no memrefs or memrefs without 403 // affine map, `newOp` is the same with `op` and following process is 404 // skipped. 405 if (op != newOp) { 406 bool replacingMemRefUsesFailed = false; 407 for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) { 408 // Replace all uses of the old memrefs. 409 Value oldMemRef = op->getResult(resIndex); 410 Value newMemRef = newOp->getResult(resIndex); 411 MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>(); 412 // Check whether the operation result is MemRef type. 413 if (!oldMemRefType) 414 continue; 415 MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>(); 416 if (oldMemRefType == newMemRefType) 417 continue; 418 // TODO: Assume single layout map. Multiple maps not supported. 419 AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap(); 420 if (failed(replaceAllMemRefUsesWith(oldMemRef, 421 /*newMemRef=*/newMemRef, 422 /*extraIndices=*/{}, 423 /*indexRemap=*/layoutMap, 424 /*extraOperands=*/{}, 425 /*symbolOperands=*/{}, 426 /*domOpFilter=*/nullptr, 427 /*postDomOpFilter=*/nullptr, 428 /*allowNonDereferencingOps=*/true, 429 /*replaceInDeallocOp=*/true))) { 430 newOp->erase(); 431 replacingMemRefUsesFailed = true; 432 continue; 433 } 434 } 435 if (!replacingMemRefUsesFailed) { 436 // Replace other ops with new op and delete the old op when the 437 // replacement succeeded. 438 op->replaceAllUsesWith(newOp); 439 op->erase(); 440 } 441 } 442 } 443 }); 444 445 // In a normal function, memrefs in the return type signature gets normalized 446 // as a result of normalization of functions arguments, AllocOps or CallOps' 447 // result types. Since an external function doesn't have a body, memrefs in 448 // the return type signature can only get normalized by iterating over the 449 // individual return types. 450 if (funcOp.isExternal()) { 451 SmallVector<Type, 4> resultTypes; 452 for (unsigned resIndex : 453 llvm::seq<unsigned>(0, functionType.getNumResults())) { 454 Type resType = functionType.getResult(resIndex); 455 MemRefType memrefType = resType.dyn_cast<MemRefType>(); 456 // Check whether result is of MemRef type. Any other argument type can 457 // simply be part of the final function signature. 458 if (!memrefType) { 459 resultTypes.push_back(resType); 460 continue; 461 } 462 // Computing a new memref type after normalizing the old memref to have an 463 // identity map layout. 464 MemRefType newMemRefType = normalizeMemRefType(memrefType, b, 465 /*numSymbolicOperands=*/0); 466 resultTypes.push_back(newMemRefType); 467 } 468 469 FunctionType newFuncType = 470 FunctionType::get(&getContext(), /*inputs=*/inputTypes, 471 /*results=*/resultTypes); 472 // Setting the new function signature for this external function. 473 funcOp.setType(newFuncType); 474 } 475 updateFunctionSignature(funcOp, moduleOp); 476 } 477 478 /// Create an operation containing normalized memrefs in the operation results. 479 /// When the results of `oldOp` have memrefs with affine map, the memrefs are 480 /// normalized, and new operation containing them in the operation results is 481 /// returned. If all of the results of `oldOp` have no memrefs or memrefs 482 /// without affine map, `oldOp` is returned without modification. 483 Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp, 484 Operation *oldOp) { 485 // Prepare OperationState to create newOp containing normalized memref in 486 // the operation results. 487 OperationState result(oldOp->getLoc(), oldOp->getName()); 488 result.addOperands(oldOp->getOperands()); 489 result.addAttributes(oldOp->getAttrs()); 490 // Add normalized MemRefType to the OperationState. 491 SmallVector<Type, 4> resultTypes; 492 OpBuilder b(funcOp); 493 bool resultTypeNormalized = false; 494 for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) { 495 auto resultType = oldOp->getResult(resIndex).getType(); 496 MemRefType memrefType = resultType.dyn_cast<MemRefType>(); 497 // Check whether the operation result is MemRef type. 498 if (!memrefType) { 499 resultTypes.push_back(resultType); 500 continue; 501 } 502 // Fetch a new memref type after normalizing the old memref. 503 MemRefType newMemRefType = normalizeMemRefType(memrefType, b, 504 /*numSymbolicOperands=*/0); 505 if (newMemRefType == memrefType) { 506 // Either memrefType already had an identity map or the map couldn't 507 // be transformed to an identity map. 508 resultTypes.push_back(memrefType); 509 continue; 510 } 511 resultTypes.push_back(newMemRefType); 512 resultTypeNormalized = true; 513 } 514 result.addTypes(resultTypes); 515 // When all of the results of `oldOp` have no memrefs or memrefs without 516 // affine map, `oldOp` is returned without modification. 517 if (resultTypeNormalized) { 518 OpBuilder bb(oldOp); 519 for (auto &oldRegion : oldOp->getRegions()) { 520 Region *newRegion = result.addRegion(); 521 newRegion->takeBody(oldRegion); 522 } 523 return bb.create(result); 524 } 525 return oldOp; 526 } 527