1 //===- NormalizeMemRefs.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an interprocedural pass to normalize memrefs to have 10 // identity layout maps. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Affine/Utils.h" 17 #include "mlir/Dialect/Func/IR/FuncOps.h" 18 #include "mlir/Dialect/MemRef/IR/MemRef.h" 19 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 20 #include "llvm/ADT/SmallSet.h" 21 #include "llvm/Support/Debug.h" 22 23 #define DEBUG_TYPE "normalize-memrefs" 24 25 using namespace mlir; 26 27 namespace { 28 29 /// All memrefs passed across functions with non-trivial layout maps are 30 /// converted to ones with trivial identity layout ones. 31 /// If all the memref types/uses in a function are normalizable, we treat 32 /// such functions as normalizable. Also, if a normalizable function is known 33 /// to call a non-normalizable function, we treat that function as 34 /// non-normalizable as well. We assume external functions to be normalizable. 
struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
  void runOnOperation() override;
  // Normalizes all memrefs within `funcOp` (alloc results, call results and
  // block arguments) and updates the function signature accordingly.
  void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
  // Returns true if every memref use inside `funcOp` is normalizable.
  bool areMemRefsNormalizable(FuncOp funcOp);
  // Rewrites `funcOp`'s type after normalization and propagates the change to
  // call sites (recursively updating callers whose return types change).
  void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
  // Transitively marks `funcOp` and all of its callers/callees as
  // non-normalizable by erasing them from `normalizableFuncs`.
  void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
                                           DenseSet<FuncOp> &normalizableFuncs);
  // Clones `oldOp` with memref result types normalized; returns `oldOp`
  // unchanged when no result type needed normalization.
  Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
};

} // namespace

std::unique_ptr<OperationPass<ModuleOp>>
mlir::memref::createNormalizeMemRefsPass() {
  return std::make_unique<NormalizeMemRefs>();
}

void NormalizeMemRefs::runOnOperation() {
  LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
  ModuleOp moduleOp = getOperation();
  // We maintain all normalizable FuncOps in a DenseSet. It is initialized
  // with all the functions within a module and then functions which are not
  // normalizable are removed from this set.
  // TODO: Change this to work on FuncLikeOp once there is an operation
  // interface for it.
  DenseSet<FuncOp> normalizableFuncs;
  // Initialize `normalizableFuncs` with all the functions within a module.
  moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });

  // Traverse through all the functions applying a filter which determines
  // whether that function is normalizable or not. All callers/callees of
  // a non-normalizable function will also become non-normalizable even if
  // they aren't passing any or specific non-normalizable memrefs. So,
  // functions which call or get called by a non-normalizable function become
  // non-normalizable functions themselves.
  moduleOp.walk([&](FuncOp funcOp) {
    if (normalizableFuncs.contains(funcOp)) {
      if (!areMemRefsNormalizable(funcOp)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "@" << funcOp.getName()
                   << " contains ops that cannot normalize MemRefs\n");
        // Since this function is not normalizable, we set all the caller
        // functions and the callees of this function as not normalizable.
        // TODO: Drop this conservative assumption in the future.
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
      }
    }
  });

  LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
                          << " functions\n");
  // Those functions which can be normalized are subjected to normalization.
  for (FuncOp &funcOp : normalizableFuncs)
    normalizeFuncOpMemRefs(funcOp, moduleOp);
}

/// Check whether all the uses of oldMemRef are either dereferencing uses or the
/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
/// are satisfied will the value become a candidate for replacement.
/// TODO: Extend this for DimOps.
static bool isMemRefNormalizable(Value::user_range opUsers) {
  return llvm::all_of(opUsers, [](Operation *op) {
    return op->hasTrait<OpTrait::MemRefsNormalizable>();
  });
}

/// Set all the calling functions and the callees of the function as not
/// normalizable.
void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
    FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
  // Already marked non-normalizable: nothing to propagate (also acts as the
  // base case for the recursion below).
  if (!normalizableFuncs.contains(funcOp))
    return;

  LLVM_DEBUG(
      llvm::dbgs() << "@" << funcOp.getName()
                   << " calls or is called by non-normalizable function\n");
  normalizableFuncs.erase(funcOp);
  // Caller of the function.
  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
    // TODO: Extend this for ops that are FunctionOpInterface. This would
    // require creating an OpInterface for FunctionOpInterface ops.
    FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
    for (FuncOp &funcOp : normalizableFuncs) {
      if (parentFuncOp == funcOp) {
        // Recurse so the caller's own callers/callees are marked too.
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
        break;
      }
    }
  }

  // Functions called by this function.
  funcOp.walk([&](func::CallOp callOp) {
    StringAttr callee = callOp.getCalleeAttr().getAttr();
    for (FuncOp &funcOp : normalizableFuncs) {
      // We compare FuncOp and callee's name.
      if (callee == funcOp.getNameAttr()) {
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
        break;
      }
    }
  });
}

/// Check whether all the uses of AllocOps, CallOps and function arguments of a
/// function are either of dereferencing type or are uses in: DeallocOp, CallOp
/// or ReturnOp. Only if these constraints are satisfied will the function
/// become a candidate for normalization. We follow a conservative approach here
/// wherein even if the non-normalizable memref is not a part of the function's
/// argument or return type, we still label the entire function as
/// non-normalizable. We assume external functions to be normalizable.
bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
  // We assume external functions to be normalizable.
  if (funcOp.isExternal())
    return true;

  // Check memrefs produced by AllocOps.
  if (funcOp
          .walk([&](memref::AllocOp allocOp) -> WalkResult {
            Value oldMemRef = allocOp.getResult();
            if (!isMemRefNormalizable(oldMemRef.getUsers()))
              return WalkResult::interrupt();
            return WalkResult::advance();
          })
          .wasInterrupted())
    return false;

  // Check memrefs returned by CallOps.
  if (funcOp
          .walk([&](func::CallOp callOp) -> WalkResult {
            for (unsigned resIndex :
                 llvm::seq<unsigned>(0, callOp.getNumResults())) {
              Value oldMemRef = callOp.getResult(resIndex);
              if (oldMemRef.getType().isa<MemRefType>())
                if (!isMemRefNormalizable(oldMemRef.getUsers()))
                  return WalkResult::interrupt();
            }
            return WalkResult::advance();
          })
          .wasInterrupted())
    return false;

  // Check memref-typed function arguments.
  for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
    BlockArgument oldMemRef = funcOp.getArgument(argIndex);
    if (oldMemRef.getType().isa<MemRefType>())
      if (!isMemRefNormalizable(oldMemRef.getUsers()))
        return false;
  }

  return true;
}

/// Fetch the updated argument list and result of the function and update the
/// function signature. This updates the function's return type at the caller
/// site and in case the return type is a normalized memref then it updates
/// the calling function's signature.
/// TODO: An update to the calling function signature is required only if the
/// returned value is in turn used in ReturnOp of the calling function.
void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
                                               ModuleOp moduleOp) {
  FunctionType functionType = funcOp.getType();
  SmallVector<Type, 4> resultTypes;
  FunctionType newFuncType;
  resultTypes = llvm::to_vector<4>(functionType.getResults());

  // External function's signature was already updated in
  // 'normalizeFuncOpMemRefs()'.
  if (!funcOp.isExternal()) {
    SmallVector<Type, 8> argTypes;
    for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
      argTypes.push_back(argEn.value().getType());

    // Traverse ReturnOps to check if an update to the return type in the
    // function signature is required.
    funcOp.walk([&](func::ReturnOp returnOp) {
      for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
        Type opType = operandEn.value().getType();
        MemRefType memrefType = opType.dyn_cast<MemRefType>();
        // If type is not memref or if the memref type is same as that in
        // function's return signature then no update is required.
        if (!memrefType || memrefType == resultTypes[operandEn.index()])
          continue;
        // Update function's return type signature.
        // Return type gets normalized either as a result of function argument
        // normalization, AllocOp normalization or an update made at CallOp.
        // There can be many call flows inside a function and an update to a
        // specific ReturnOp has not yet been made. So we check that the result
        // memref type is normalized.
        // TODO: When selective normalization is implemented, handle multiple
        // results case where some are normalized, some aren't.
        if (memrefType.getLayout().isIdentity())
          resultTypes[operandEn.index()] = memrefType;
      }
    });

    // We create a new function type and modify the function signature with this
    // new type.
    newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
                                    /*results=*/resultTypes);
  }

  // Since we update the function signature, it might affect the result types at
  // the caller site. Since this result might even be used by the caller
  // function in ReturnOps, the caller function's signature will also change.
  // Hence we record the caller function in 'funcOpsToUpdate' to update their
  // signature as well.
  llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
  // We iterate over all symbolic uses of the function and update the return
  // type at the caller site.
  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
    Operation *userOp = symbolUse.getUser();
    OpBuilder builder(userOp);
    // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
    // that the non-CallOp has no memrefs to be replaced.
    // TODO: Handle cases where a non-CallOp symbol use of a function deals with
    // memrefs.
    auto callOp = dyn_cast<func::CallOp>(userOp);
    if (!callOp)
      continue;
    // Build a replacement call with the updated result types; the old call is
    // erased below once all memref uses have been rewritten.
    Operation *newCallOp =
        builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
                                     resultTypes, userOp->getOperands());
    bool replacingMemRefUsesFailed = false;
    bool returnTypeChanged = false;
    for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
      OpResult oldResult = userOp->getResult(resIndex);
      OpResult newResult = newCallOp->getResult(resIndex);
      // This condition ensures that if the result is not of type memref or if
      // the resulting memref was already having a trivial map layout then we
      // need not perform any use replacement here.
      if (oldResult.getType() == newResult.getType())
        continue;
      AffineMap layoutMap =
          oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
      if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
                                          /*extraIndices=*/{},
                                          /*indexRemap=*/layoutMap,
                                          /*extraOperands=*/{},
                                          /*symbolOperands=*/{},
                                          /*domOpFilter=*/nullptr,
                                          /*postDomOpFilter=*/nullptr,
                                          /*allowNonDereferencingOps=*/true,
                                          /*replaceInDeallocOp=*/true))) {
        // If it failed (due to escapes for example), bail out.
        // It should never hit this part of the code because it is called by
        // only those functions which are normalizable.
        newCallOp->erase();
        replacingMemRefUsesFailed = true;
        break;
      }
      returnTypeChanged = true;
    }
    if (replacingMemRefUsesFailed)
      continue;
    // Replace all uses for other non-memref result types.
    userOp->replaceAllUsesWith(newCallOp);
    userOp->erase();
    if (returnTypeChanged) {
      // Since the return type changed it might lead to a change in function's
      // signature.
      // TODO: If funcOp doesn't return any memref type then no need to update
      // signature.
      // TODO: Further optimization - Check if the memref is indeed part of
      // ReturnOp at the parentFuncOp and only then updation of signature is
      // required.
      // TODO: Extend this for ops that are FunctionOpInterface. This would
      // require creating an OpInterface for FunctionOpInterface ops.
      FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>();
      funcOpsToUpdate.insert(parentFuncOp);
    }
  }
  // Because external function's signature is already updated in
  // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
  if (!funcOp.isExternal())
    funcOp.setType(newFuncType);

  // Updating the signature type of those functions which call the current
  // function. Only if the return type of the current function has a normalized
  // memref will the caller function become a candidate for signature update.
  for (FuncOp parentFuncOp : funcOpsToUpdate)
    updateFunctionSignature(parentFuncOp, moduleOp);
}

/// Normalizes the memrefs within a function which includes those arising as a
/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
/// is used to help update function's signature after normalization.
void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
                                              ModuleOp moduleOp) {
  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
  // alloc ops first and then process since normalizeMemRef replaces/erases ops
  // during memref rewriting.
  SmallVector<memref::AllocOp, 4> allocOps;
  funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
  for (memref::AllocOp allocOp : allocOps)
    (void)normalizeMemRef(&allocOp);

  // We use this OpBuilder to create new memref layout later.
  OpBuilder b(funcOp);

  FunctionType functionType = funcOp.getType();
  SmallVector<Location> functionArgLocs(llvm::map_range(
      funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
  SmallVector<Type, 8> inputTypes;
  // Walk over each argument of a function to perform memref normalization (if
  // possible).
  for (unsigned argIndex :
       llvm::seq<unsigned>(0, functionType.getNumInputs())) {
    Type argType = functionType.getInput(argIndex);
    MemRefType memrefType = argType.dyn_cast<MemRefType>();
    // Check whether argument is of MemRef type. Any other argument type can
    // simply be part of the final function signature.
    if (!memrefType) {
      inputTypes.push_back(argType);
      continue;
    }
    // Fetch a new memref type after normalizing the old memref to have an
    // identity map layout.
    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                   /*numSymbolicOperands=*/0);
    if (newMemRefType == memrefType || funcOp.isExternal()) {
      // Either memrefType already had an identity map or the map couldn't be
      // transformed to an identity map.
      inputTypes.push_back(newMemRefType);
      continue;
    }

    // Insert a new temporary argument with the new memref type.
    BlockArgument newMemRef = funcOp.front().insertArgument(
        argIndex, newMemRefType, functionArgLocs[argIndex]);
    // The old argument has shifted by one position due to the insertion above.
    BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
    AffineMap layoutMap = memrefType.getLayout().getAffineMap();
    // Replace all uses of the old memref.
    if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
                                        /*extraIndices=*/{},
                                        /*indexRemap=*/layoutMap,
                                        /*extraOperands=*/{},
                                        /*symbolOperands=*/{},
                                        /*domOpFilter=*/nullptr,
                                        /*postDomOpFilter=*/nullptr,
                                        /*allowNonDereferencingOps=*/true,
                                        /*replaceInDeallocOp=*/true))) {
      // If it failed (due to escapes for example), bail out. Removing the
      // temporary argument inserted previously.
      funcOp.front().eraseArgument(argIndex);
      continue;
    }

    // All uses for the argument with old memref type were replaced
    // successfully. So we remove the old argument now.
    funcOp.front().eraseArgument(argIndex + 1);
  }

  // Walk over normalizable operations to normalize memrefs of the operation
  // results. When `op` has memrefs with affine map in the operation results,
  // new operation containing normalized memrefs is created. Then, the memrefs
  // are replaced. `CallOp` is skipped here because it is handled in
  // `updateFunctionSignature()`.
  funcOp.walk([&](Operation *op) {
    if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
        op->getNumResults() > 0 && !isa<func::CallOp>(op) &&
        !funcOp.isExternal()) {
      // Create newOp containing normalized memref in the operation result.
      Operation *newOp = createOpResultsNormalized(funcOp, op);
      // When all of the operation results have no memrefs or memrefs without
      // affine map, `newOp` is the same with `op` and following process is
      // skipped.
      if (op != newOp) {
        bool replacingMemRefUsesFailed = false;
        for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
          // Replace all uses of the old memrefs.
          Value oldMemRef = op->getResult(resIndex);
          Value newMemRef = newOp->getResult(resIndex);
          MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
          // Check whether the operation result is MemRef type.
          if (!oldMemRefType)
            continue;
          MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
          if (oldMemRefType == newMemRefType)
            continue;
          // TODO: Assume single layout map. Multiple maps not supported.
          AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
          if (failed(replaceAllMemRefUsesWith(oldMemRef,
                                              /*newMemRef=*/newMemRef,
                                              /*extraIndices=*/{},
                                              /*indexRemap=*/layoutMap,
                                              /*extraOperands=*/{},
                                              /*symbolOperands=*/{},
                                              /*domOpFilter=*/nullptr,
                                              /*postDomOpFilter=*/nullptr,
                                              /*allowNonDereferencingOps=*/true,
                                              /*replaceInDeallocOp=*/true))) {
            newOp->erase();
            replacingMemRefUsesFailed = true;
            continue;
          }
        }
        if (!replacingMemRefUsesFailed) {
          // Replace other ops with new op and delete the old op when the
          // replacement succeeded.
          op->replaceAllUsesWith(newOp);
          op->erase();
        }
      }
    }
  });

  // In a normal function, memrefs in the return type signature gets normalized
  // as a result of normalization of functions arguments, AllocOps or CallOps'
  // result types. Since an external function doesn't have a body, memrefs in
  // the return type signature can only get normalized by iterating over the
  // individual return types.
  if (funcOp.isExternal()) {
    SmallVector<Type, 4> resultTypes;
    for (unsigned resIndex :
         llvm::seq<unsigned>(0, functionType.getNumResults())) {
      Type resType = functionType.getResult(resIndex);
      MemRefType memrefType = resType.dyn_cast<MemRefType>();
      // Check whether result is of MemRef type. Any other argument type can
      // simply be part of the final function signature.
      if (!memrefType) {
        resultTypes.push_back(resType);
        continue;
      }
      // Computing a new memref type after normalizing the old memref to have an
      // identity map layout.
      MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                     /*numSymbolicOperands=*/0);
      resultTypes.push_back(newMemRefType);
    }

    FunctionType newFuncType =
        FunctionType::get(&getContext(), /*inputs=*/inputTypes,
                          /*results=*/resultTypes);
    // Setting the new function signature for this external function.
    funcOp.setType(newFuncType);
  }
  updateFunctionSignature(funcOp, moduleOp);
}

/// Create an operation containing normalized memrefs in the operation results.
/// When the results of `oldOp` have memrefs with affine map, the memrefs are
/// normalized, and new operation containing them in the operation results is
/// returned. If all of the results of `oldOp` have no memrefs or memrefs
/// without affine map, `oldOp` is returned without modification.
Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
                                                       Operation *oldOp) {
  // Prepare OperationState to create newOp containing normalized memref in
  // the operation results.
  OperationState result(oldOp->getLoc(), oldOp->getName());
  result.addOperands(oldOp->getOperands());
  result.addAttributes(oldOp->getAttrs());
  // Add normalized MemRefType to the OperationState.
  SmallVector<Type, 4> resultTypes;
  OpBuilder b(funcOp);
  bool resultTypeNormalized = false;
  for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
    auto resultType = oldOp->getResult(resIndex).getType();
    MemRefType memrefType = resultType.dyn_cast<MemRefType>();
    // Check whether the operation result is MemRef type.
    if (!memrefType) {
      resultTypes.push_back(resultType);
      continue;
    }
    // Fetch a new memref type after normalizing the old memref.
    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                   /*numSymbolicOperands=*/0);
    if (newMemRefType == memrefType) {
      // Either memrefType already had an identity map or the map couldn't
      // be transformed to an identity map.
      resultTypes.push_back(memrefType);
      continue;
    }
    resultTypes.push_back(newMemRefType);
    resultTypeNormalized = true;
  }
  result.addTypes(resultTypes);
  // When all of the results of `oldOp` have no memrefs or memrefs without
  // affine map, `oldOp` is returned without modification.
  if (resultTypeNormalized) {
    OpBuilder bb(oldOp);
    // Move region bodies (not copies) into the new op before creating it.
    for (auto &oldRegion : oldOp->getRegions()) {
      Region *newRegion = result.addRegion();
      newRegion->takeBody(oldRegion);
    }
    return bb.createOperation(result);
  }
  return oldOp;
}