1 //===- NormalizeMemRefs.cpp -----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an interprocedural pass to normalize memrefs to have 10 // identity layout maps. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/Affine/IR/AffineOps.h" 16 #include "mlir/Dialect/Affine/Utils.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 19 #include "llvm/ADT/SmallSet.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "normalize-memrefs" 23 24 using namespace mlir; 25 26 namespace { 27 28 /// All memrefs passed across functions with non-trivial layout maps are 29 /// converted to ones with trivial identity layout ones. 30 /// If all the memref types/uses in a function are normalizable, we treat 31 /// such functions as normalizable. Also, if a normalizable function is known 32 /// to call a non-normalizable function, we treat that function as 33 /// non-normalizable as well. We assume external functions to be normalizable. 34 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> { 35 void runOnOperation() override; 36 void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp); 37 bool areMemRefsNormalizable(FuncOp funcOp); 38 void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp); 39 void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp, 40 DenseSet<FuncOp> &normalizableFuncs); 41 Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp); 42 }; 43 44 } // namespace 45 46 std::unique_ptr<OperationPass<ModuleOp>> 47 mlir::memref::createNormalizeMemRefsPass() { 48 return std::make_unique<NormalizeMemRefs>(); 49 } 50 51 void NormalizeMemRefs::runOnOperation() { 52 LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n"); 53 ModuleOp moduleOp = getOperation(); 54 // We maintain all normalizable FuncOps in a DenseSet. It is initialized 55 // with all the functions within a module and then functions which are not 56 // normalizable are removed from this set. 57 // TODO: Change this to work on FuncLikeOp once there is an operation 58 // interface for it. 59 DenseSet<FuncOp> normalizableFuncs; 60 // Initialize `normalizableFuncs` with all the functions within a module. 61 moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); }); 62 63 // Traverse through all the functions applying a filter which determines 64 // whether that function is normalizable or not. All callers/callees of 65 // a non-normalizable function will also become non-normalizable even if 66 // they aren't passing any or specific non-normalizable memrefs. So, 67 // functions which calls or get called by a non-normalizable becomes non- 68 // normalizable functions themselves. 69 moduleOp.walk([&](FuncOp funcOp) { 70 if (normalizableFuncs.contains(funcOp)) { 71 if (!areMemRefsNormalizable(funcOp)) { 72 LLVM_DEBUG(llvm::dbgs() 73 << "@" << funcOp.getName() 74 << " contains ops that cannot normalize MemRefs\n"); 75 // Since this function is not normalizable, we set all the caller 76 // functions and the callees of this function as not normalizable. 77 // TODO: Drop this conservative assumption in the future. 78 setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 79 normalizableFuncs); 80 } 81 } 82 }); 83 84 LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size() 85 << " functions\n"); 86 // Those functions which can be normalized are subjected to normalization. 87 for (FuncOp &funcOp : normalizableFuncs) 88 normalizeFuncOpMemRefs(funcOp, moduleOp); 89 } 90 91 /// Check whether all the uses of oldMemRef are either dereferencing uses or the 92 /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints 93 /// are satisfied will the value become a candidate for replacement. 94 /// TODO: Extend this for DimOps. 95 static bool isMemRefNormalizable(Value::user_range opUsers) { 96 return llvm::all_of(opUsers, [](Operation *op) { 97 return op->hasTrait<OpTrait::MemRefsNormalizable>(); 98 }); 99 } 100 101 /// Set all the calling functions and the callees of the function as not 102 /// normalizable. 103 void NormalizeMemRefs::setCalleesAndCallersNonNormalizable( 104 FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) { 105 if (!normalizableFuncs.contains(funcOp)) 106 return; 107 108 LLVM_DEBUG( 109 llvm::dbgs() << "@" << funcOp.getName() 110 << " calls or is called by non-normalizable function\n"); 111 normalizableFuncs.erase(funcOp); 112 // Caller of the function. 113 Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp); 114 for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 115 // TODO: Extend this for ops that are FunctionOpInterface. This would 116 // require creating an OpInterface for FunctionOpInterface ops. 117 FuncOp parentFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>(); 118 for (FuncOp &funcOp : normalizableFuncs) { 119 if (parentFuncOp == funcOp) { 120 setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 121 normalizableFuncs); 122 break; 123 } 124 } 125 } 126 127 // Functions called by this function. 128 funcOp.walk([&](CallOp callOp) { 129 StringAttr callee = callOp.getCalleeAttr().getAttr(); 130 for (FuncOp &funcOp : normalizableFuncs) { 131 // We compare FuncOp and callee's name. 132 if (callee == funcOp.getNameAttr()) { 133 setCalleesAndCallersNonNormalizable(funcOp, moduleOp, 134 normalizableFuncs); 135 break; 136 } 137 } 138 }); 139 } 140 141 /// Check whether all the uses of AllocOps, CallOps and function arguments of a 142 /// function are either of dereferencing type or are uses in: DeallocOp, CallOp 143 /// or ReturnOp. Only if these constraints are satisfied will the function 144 /// become a candidate for normalization. We follow a conservative approach here 145 /// wherein even if the non-normalizable memref is not a part of the function's 146 /// argument or return type, we still label the entire function as 147 /// non-normalizable. We assume external functions to be normalizable. 148 bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) { 149 // We assume external functions to be normalizable. 150 if (funcOp.isExternal()) 151 return true; 152 153 if (funcOp 154 .walk([&](memref::AllocOp allocOp) -> WalkResult { 155 Value oldMemRef = allocOp.getResult(); 156 if (!isMemRefNormalizable(oldMemRef.getUsers())) 157 return WalkResult::interrupt(); 158 return WalkResult::advance(); 159 }) 160 .wasInterrupted()) 161 return false; 162 163 if (funcOp 164 .walk([&](CallOp callOp) -> WalkResult { 165 for (unsigned resIndex : 166 llvm::seq<unsigned>(0, callOp.getNumResults())) { 167 Value oldMemRef = callOp.getResult(resIndex); 168 if (oldMemRef.getType().isa<MemRefType>()) 169 if (!isMemRefNormalizable(oldMemRef.getUsers())) 170 return WalkResult::interrupt(); 171 } 172 return WalkResult::advance(); 173 }) 174 .wasInterrupted()) 175 return false; 176 177 for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) { 178 BlockArgument oldMemRef = funcOp.getArgument(argIndex); 179 if (oldMemRef.getType().isa<MemRefType>()) 180 if (!isMemRefNormalizable(oldMemRef.getUsers())) 181 return false; 182 } 183 184 return true; 185 } 186 187 /// Fetch the updated argument list and result of the function and update the 188 /// function signature. This updates the function's return type at the caller 189 /// site and in case the return type is a normalized memref then it updates 190 /// the calling function's signature. 191 /// TODO: An update to the calling function signature is required only if the 192 /// returned value is in turn used in ReturnOp of the calling function. 193 void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp, 194 ModuleOp moduleOp) { 195 FunctionType functionType = funcOp.getType(); 196 SmallVector<Type, 4> resultTypes; 197 FunctionType newFuncType; 198 resultTypes = llvm::to_vector<4>(functionType.getResults()); 199 200 // External function's signature was already updated in 201 // 'normalizeFuncOpMemRefs()'. 202 if (!funcOp.isExternal()) { 203 SmallVector<Type, 8> argTypes; 204 for (const auto &argEn : llvm::enumerate(funcOp.getArguments())) 205 argTypes.push_back(argEn.value().getType()); 206 207 // Traverse ReturnOps to check if an update to the return type in the 208 // function signature is required. 209 funcOp.walk([&](ReturnOp returnOp) { 210 for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) { 211 Type opType = operandEn.value().getType(); 212 MemRefType memrefType = opType.dyn_cast<MemRefType>(); 213 // If type is not memref or if the memref type is same as that in 214 // function's return signature then no update is required. 215 if (!memrefType || memrefType == resultTypes[operandEn.index()]) 216 continue; 217 // Update function's return type signature. 218 // Return type gets normalized either as a result of function argument 219 // normalization, AllocOp normalization or an update made at CallOp. 220 // There can be many call flows inside a function and an update to a 221 // specific ReturnOp has not yet been made. So we check that the result 222 // memref type is normalized. 223 // TODO: When selective normalization is implemented, handle multiple 224 // results case where some are normalized, some aren't. 225 if (memrefType.getLayout().isIdentity()) 226 resultTypes[operandEn.index()] = memrefType; 227 } 228 }); 229 230 // We create a new function type and modify the function signature with this 231 // new type. 232 newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes, 233 /*results=*/resultTypes); 234 } 235 236 // Since we update the function signature, it might affect the result types at 237 // the caller site. Since this result might even be used by the caller 238 // function in ReturnOps, the caller function's signature will also change. 239 // Hence we record the caller function in 'funcOpsToUpdate' to update their 240 // signature as well. 241 llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate; 242 // We iterate over all symbolic uses of the function and update the return 243 // type at the caller site. 244 Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp); 245 for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 246 Operation *userOp = symbolUse.getUser(); 247 OpBuilder builder(userOp); 248 // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes 249 // that the non-CallOp has no memrefs to be replaced. 250 // TODO: Handle cases where a non-CallOp symbol use of a function deals with 251 // memrefs. 252 auto callOp = dyn_cast<CallOp>(userOp); 253 if (!callOp) 254 continue; 255 Operation *newCallOp = 256 builder.create<CallOp>(userOp->getLoc(), callOp.getCalleeAttr(), 257 resultTypes, userOp->getOperands()); 258 bool replacingMemRefUsesFailed = false; 259 bool returnTypeChanged = false; 260 for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) { 261 OpResult oldResult = userOp->getResult(resIndex); 262 OpResult newResult = newCallOp->getResult(resIndex); 263 // This condition ensures that if the result is not of type memref or if 264 // the resulting memref was already having a trivial map layout then we 265 // need not perform any use replacement here. 266 if (oldResult.getType() == newResult.getType()) 267 continue; 268 AffineMap layoutMap = 269 oldResult.getType().cast<MemRefType>().getLayout().getAffineMap(); 270 if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult, 271 /*extraIndices=*/{}, 272 /*indexRemap=*/layoutMap, 273 /*extraOperands=*/{}, 274 /*symbolOperands=*/{}, 275 /*domOpFilter=*/nullptr, 276 /*postDomOpFilter=*/nullptr, 277 /*allowNonDereferencingOps=*/true, 278 /*replaceInDeallocOp=*/true))) { 279 // If it failed (due to escapes for example), bail out. 280 // It should never hit this part of the code because it is called by 281 // only those functions which are normalizable. 282 newCallOp->erase(); 283 replacingMemRefUsesFailed = true; 284 break; 285 } 286 returnTypeChanged = true; 287 } 288 if (replacingMemRefUsesFailed) 289 continue; 290 // Replace all uses for other non-memref result types. 291 userOp->replaceAllUsesWith(newCallOp); 292 userOp->erase(); 293 if (returnTypeChanged) { 294 // Since the return type changed it might lead to a change in function's 295 // signature. 296 // TODO: If funcOp doesn't return any memref type then no need to update 297 // signature. 298 // TODO: Further optimization - Check if the memref is indeed part of 299 // ReturnOp at the parentFuncOp and only then updation of signature is 300 // required. 301 // TODO: Extend this for ops that are FunctionOpInterface. This would 302 // require creating an OpInterface for FunctionOpInterface ops. 303 FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>(); 304 funcOpsToUpdate.insert(parentFuncOp); 305 } 306 } 307 // Because external function's signature is already updated in 308 // 'normalizeFuncOpMemRefs()', we don't need to update it here again. 309 if (!funcOp.isExternal()) 310 funcOp.setType(newFuncType); 311 312 // Updating the signature type of those functions which call the current 313 // function. Only if the return type of the current function has a normalized 314 // memref will the caller function become a candidate for signature update. 315 for (FuncOp parentFuncOp : funcOpsToUpdate) 316 updateFunctionSignature(parentFuncOp, moduleOp); 317 } 318 319 /// Normalizes the memrefs within a function which includes those arising as a 320 /// result of AllocOps, CallOps and function's argument. The ModuleOp argument 321 /// is used to help update function's signature after normalization. 322 void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp, 323 ModuleOp moduleOp) { 324 // Turn memrefs' non-identity layouts maps into ones with identity. Collect 325 // alloc ops first and then process since normalizeMemRef replaces/erases ops 326 // during memref rewriting. 327 SmallVector<memref::AllocOp, 4> allocOps; 328 funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); }); 329 for (memref::AllocOp allocOp : allocOps) 330 (void)normalizeMemRef(&allocOp); 331 332 // We use this OpBuilder to create new memref layout later. 333 OpBuilder b(funcOp); 334 335 FunctionType functionType = funcOp.getType(); 336 SmallVector<Location> functionArgLocs(llvm::map_range( 337 funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); })); 338 SmallVector<Type, 8> inputTypes; 339 // Walk over each argument of a function to perform memref normalization (if 340 for (unsigned argIndex : 341 llvm::seq<unsigned>(0, functionType.getNumInputs())) { 342 Type argType = functionType.getInput(argIndex); 343 MemRefType memrefType = argType.dyn_cast<MemRefType>(); 344 // Check whether argument is of MemRef type. Any other argument type can 345 // simply be part of the final function signature. 346 if (!memrefType) { 347 inputTypes.push_back(argType); 348 continue; 349 } 350 // Fetch a new memref type after normalizing the old memref to have an 351 // identity map layout. 352 MemRefType newMemRefType = normalizeMemRefType(memrefType, b, 353 /*numSymbolicOperands=*/0); 354 if (newMemRefType == memrefType || funcOp.isExternal()) { 355 // Either memrefType already had an identity map or the map couldn't be 356 // transformed to an identity map. 357 inputTypes.push_back(newMemRefType); 358 continue; 359 } 360 361 // Insert a new temporary argument with the new memref type. 362 BlockArgument newMemRef = funcOp.front().insertArgument( 363 argIndex, newMemRefType, functionArgLocs[argIndex]); 364 BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1); 365 AffineMap layoutMap = memrefType.getLayout().getAffineMap(); 366 // Replace all uses of the old memref. 367 if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef, 368 /*extraIndices=*/{}, 369 /*indexRemap=*/layoutMap, 370 /*extraOperands=*/{}, 371 /*symbolOperands=*/{}, 372 /*domOpFilter=*/nullptr, 373 /*postDomOpFilter=*/nullptr, 374 /*allowNonDereferencingOps=*/true, 375 /*replaceInDeallocOp=*/true))) { 376 // If it failed (due to escapes for example), bail out. Removing the 377 // temporary argument inserted previously. 378 funcOp.front().eraseArgument(argIndex); 379 continue; 380 } 381 382 // All uses for the argument with old memref type were replaced 383 // successfully. So we remove the old argument now. 384 funcOp.front().eraseArgument(argIndex + 1); 385 } 386 387 // Walk over normalizable operations to normalize memrefs of the operation 388 // results. When `op` has memrefs with affine map in the operation results, 389 // new operation containin normalized memrefs is created. Then, the memrefs 390 // are replaced. `CallOp` is skipped here because it is handled in 391 // `updateFunctionSignature()`. 392 funcOp.walk([&](Operation *op) { 393 if (op->hasTrait<OpTrait::MemRefsNormalizable>() && 394 op->getNumResults() > 0 && !isa<CallOp>(op) && !funcOp.isExternal()) { 395 // Create newOp containing normalized memref in the operation result. 396 Operation *newOp = createOpResultsNormalized(funcOp, op); 397 // When all of the operation results have no memrefs or memrefs without 398 // affine map, `newOp` is the same with `op` and following process is 399 // skipped. 400 if (op != newOp) { 401 bool replacingMemRefUsesFailed = false; 402 for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) { 403 // Replace all uses of the old memrefs. 404 Value oldMemRef = op->getResult(resIndex); 405 Value newMemRef = newOp->getResult(resIndex); 406 MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>(); 407 // Check whether the operation result is MemRef type. 408 if (!oldMemRefType) 409 continue; 410 MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>(); 411 if (oldMemRefType == newMemRefType) 412 continue; 413 // TODO: Assume single layout map. Multiple maps not supported. 414 AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap(); 415 if (failed(replaceAllMemRefUsesWith(oldMemRef, 416 /*newMemRef=*/newMemRef, 417 /*extraIndices=*/{}, 418 /*indexRemap=*/layoutMap, 419 /*extraOperands=*/{}, 420 /*symbolOperands=*/{}, 421 /*domOpFilter=*/nullptr, 422 /*postDomOpFilter=*/nullptr, 423 /*allowNonDereferencingOps=*/true, 424 /*replaceInDeallocOp=*/true))) { 425 newOp->erase(); 426 replacingMemRefUsesFailed = true; 427 continue; 428 } 429 } 430 if (!replacingMemRefUsesFailed) { 431 // Replace other ops with new op and delete the old op when the 432 // replacement succeeded. 433 op->replaceAllUsesWith(newOp); 434 op->erase(); 435 } 436 } 437 } 438 }); 439 440 // In a normal function, memrefs in the return type signature gets normalized 441 // as a result of normalization of functions arguments, AllocOps or CallOps' 442 // result types. Since an external function doesn't have a body, memrefs in 443 // the return type signature can only get normalized by iterating over the 444 // individual return types. 445 if (funcOp.isExternal()) { 446 SmallVector<Type, 4> resultTypes; 447 for (unsigned resIndex : 448 llvm::seq<unsigned>(0, functionType.getNumResults())) { 449 Type resType = functionType.getResult(resIndex); 450 MemRefType memrefType = resType.dyn_cast<MemRefType>(); 451 // Check whether result is of MemRef type. Any other argument type can 452 // simply be part of the final function signature. 453 if (!memrefType) { 454 resultTypes.push_back(resType); 455 continue; 456 } 457 // Computing a new memref type after normalizing the old memref to have an 458 // identity map layout. 459 MemRefType newMemRefType = normalizeMemRefType(memrefType, b, 460 /*numSymbolicOperands=*/0); 461 resultTypes.push_back(newMemRefType); 462 } 463 464 FunctionType newFuncType = 465 FunctionType::get(&getContext(), /*inputs=*/inputTypes, 466 /*results=*/resultTypes); 467 // Setting the new function signature for this external function. 468 funcOp.setType(newFuncType); 469 } 470 updateFunctionSignature(funcOp, moduleOp); 471 } 472 473 /// Create an operation containing normalized memrefs in the operation results. 474 /// When the results of `oldOp` have memrefs with affine map, the memrefs are 475 /// normalized, and new operation containing them in the operation results is 476 /// returned. If all of the results of `oldOp` have no memrefs or memrefs 477 /// without affine map, `oldOp` is returned without modification. 478 Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp, 479 Operation *oldOp) { 480 // Prepare OperationState to create newOp containing normalized memref in 481 // the operation results. 482 OperationState result(oldOp->getLoc(), oldOp->getName()); 483 result.addOperands(oldOp->getOperands()); 484 result.addAttributes(oldOp->getAttrs()); 485 // Add normalized MemRefType to the OperationState. 486 SmallVector<Type, 4> resultTypes; 487 OpBuilder b(funcOp); 488 bool resultTypeNormalized = false; 489 for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) { 490 auto resultType = oldOp->getResult(resIndex).getType(); 491 MemRefType memrefType = resultType.dyn_cast<MemRefType>(); 492 // Check whether the operation result is MemRef type. 493 if (!memrefType) { 494 resultTypes.push_back(resultType); 495 continue; 496 } 497 // Fetch a new memref type after normalizing the old memref. 498 MemRefType newMemRefType = normalizeMemRefType(memrefType, b, 499 /*numSymbolicOperands=*/0); 500 if (newMemRefType == memrefType) { 501 // Either memrefType already had an identity map or the map couldn't 502 // be transformed to an identity map. 503 resultTypes.push_back(memrefType); 504 continue; 505 } 506 resultTypes.push_back(newMemRefType); 507 resultTypeNormalized = true; 508 } 509 result.addTypes(resultTypes); 510 // When all of the results of `oldOp` have no memrefs or memrefs without 511 // affine map, `oldOp` is returned without modification. 512 if (resultTypeNormalized) { 513 OpBuilder bb(oldOp); 514 for (auto &oldRegion : oldOp->getRegions()) { 515 Region *newRegion = result.addRegion(); 516 newRegion->takeBody(oldRegion); 517 } 518 return bb.createOperation(result); 519 } 520 return oldOp; 521 } 522