12e2c0738SRiver Riddle //===- NormalizeMemRefs.cpp -----------------------------------------------===//
22e2c0738SRiver Riddle //
32e2c0738SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42e2c0738SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
52e2c0738SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62e2c0738SRiver Riddle //
72e2c0738SRiver Riddle //===----------------------------------------------------------------------===//
82e2c0738SRiver Riddle //
92e2c0738SRiver Riddle // This file implements an interprocedural pass to normalize memrefs to have
102e2c0738SRiver Riddle // identity layout maps.
112e2c0738SRiver Riddle //
122e2c0738SRiver Riddle //===----------------------------------------------------------------------===//
132e2c0738SRiver Riddle 
142e2c0738SRiver Riddle #include "PassDetail.h"
152e2c0738SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Utils.h"
171f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
182e2c0738SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
192e2c0738SRiver Riddle #include "mlir/Dialect/MemRef/Transforms/Passes.h"
202e2c0738SRiver Riddle #include "llvm/ADT/SmallSet.h"
212e2c0738SRiver Riddle #include "llvm/Support/Debug.h"
222e2c0738SRiver Riddle 
232e2c0738SRiver Riddle #define DEBUG_TYPE "normalize-memrefs"
242e2c0738SRiver Riddle 
252e2c0738SRiver Riddle using namespace mlir;
262e2c0738SRiver Riddle 
272e2c0738SRiver Riddle namespace {
282e2c0738SRiver Riddle 
292e2c0738SRiver Riddle /// All memrefs passed across functions with non-trivial layout maps are
302e2c0738SRiver Riddle /// converted to ones with trivial identity layout ones.
312e2c0738SRiver Riddle /// If all the memref types/uses in a function are normalizable, we treat
322e2c0738SRiver Riddle /// such functions as normalizable. Also, if a normalizable function is known
332e2c0738SRiver Riddle /// to call a non-normalizable function, we treat that function as
342e2c0738SRiver Riddle /// non-normalizable as well. We assume external functions to be normalizable.
352e2c0738SRiver Riddle struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
362e2c0738SRiver Riddle   void runOnOperation() override;
37*58ceae95SRiver Riddle   void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
38*58ceae95SRiver Riddle   bool areMemRefsNormalizable(func::FuncOp funcOp);
39*58ceae95SRiver Riddle   void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
40*58ceae95SRiver Riddle   void setCalleesAndCallersNonNormalizable(
41*58ceae95SRiver Riddle       func::FuncOp funcOp, ModuleOp moduleOp,
42*58ceae95SRiver Riddle       DenseSet<func::FuncOp> &normalizableFuncs);
43*58ceae95SRiver Riddle   Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp);
442e2c0738SRiver Riddle };
452e2c0738SRiver Riddle 
462e2c0738SRiver Riddle } // namespace
472e2c0738SRiver Riddle 
482e2c0738SRiver Riddle std::unique_ptr<OperationPass<ModuleOp>>
createNormalizeMemRefsPass()492e2c0738SRiver Riddle mlir::memref::createNormalizeMemRefsPass() {
502e2c0738SRiver Riddle   return std::make_unique<NormalizeMemRefs>();
512e2c0738SRiver Riddle }
522e2c0738SRiver Riddle 
runOnOperation()532e2c0738SRiver Riddle void NormalizeMemRefs::runOnOperation() {
542e2c0738SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
552e2c0738SRiver Riddle   ModuleOp moduleOp = getOperation();
562e2c0738SRiver Riddle   // We maintain all normalizable FuncOps in a DenseSet. It is initialized
572e2c0738SRiver Riddle   // with all the functions within a module and then functions which are not
582e2c0738SRiver Riddle   // normalizable are removed from this set.
592e2c0738SRiver Riddle   // TODO: Change this to work on FuncLikeOp once there is an operation
602e2c0738SRiver Riddle   // interface for it.
61*58ceae95SRiver Riddle   DenseSet<func::FuncOp> normalizableFuncs;
622e2c0738SRiver Riddle   // Initialize `normalizableFuncs` with all the functions within a module.
63*58ceae95SRiver Riddle   moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
642e2c0738SRiver Riddle 
652e2c0738SRiver Riddle   // Traverse through all the functions applying a filter which determines
662e2c0738SRiver Riddle   // whether that function is normalizable or not. All callers/callees of
672e2c0738SRiver Riddle   // a non-normalizable function will also become non-normalizable even if
682e2c0738SRiver Riddle   // they aren't passing any or specific non-normalizable memrefs. So,
692e2c0738SRiver Riddle   // functions which calls or get called by a non-normalizable becomes non-
702e2c0738SRiver Riddle   // normalizable functions themselves.
71*58ceae95SRiver Riddle   moduleOp.walk([&](func::FuncOp funcOp) {
722e2c0738SRiver Riddle     if (normalizableFuncs.contains(funcOp)) {
732e2c0738SRiver Riddle       if (!areMemRefsNormalizable(funcOp)) {
742e2c0738SRiver Riddle         LLVM_DEBUG(llvm::dbgs()
752e2c0738SRiver Riddle                    << "@" << funcOp.getName()
762e2c0738SRiver Riddle                    << " contains ops that cannot normalize MemRefs\n");
772e2c0738SRiver Riddle         // Since this function is not normalizable, we set all the caller
782e2c0738SRiver Riddle         // functions and the callees of this function as not normalizable.
792e2c0738SRiver Riddle         // TODO: Drop this conservative assumption in the future.
802e2c0738SRiver Riddle         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
812e2c0738SRiver Riddle                                             normalizableFuncs);
822e2c0738SRiver Riddle       }
832e2c0738SRiver Riddle     }
842e2c0738SRiver Riddle   });
852e2c0738SRiver Riddle 
862e2c0738SRiver Riddle   LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
872e2c0738SRiver Riddle                           << " functions\n");
882e2c0738SRiver Riddle   // Those functions which can be normalized are subjected to normalization.
89*58ceae95SRiver Riddle   for (func::FuncOp &funcOp : normalizableFuncs)
902e2c0738SRiver Riddle     normalizeFuncOpMemRefs(funcOp, moduleOp);
912e2c0738SRiver Riddle }
922e2c0738SRiver Riddle 
932e2c0738SRiver Riddle /// Check whether all the uses of oldMemRef are either dereferencing uses or the
942e2c0738SRiver Riddle /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
952e2c0738SRiver Riddle /// are satisfied will the value become a candidate for replacement.
962e2c0738SRiver Riddle /// TODO: Extend this for DimOps.
isMemRefNormalizable(Value::user_range opUsers)972e2c0738SRiver Riddle static bool isMemRefNormalizable(Value::user_range opUsers) {
982e2c0738SRiver Riddle   return llvm::all_of(opUsers, [](Operation *op) {
992e2c0738SRiver Riddle     return op->hasTrait<OpTrait::MemRefsNormalizable>();
1002e2c0738SRiver Riddle   });
1012e2c0738SRiver Riddle }
1022e2c0738SRiver Riddle 
1032e2c0738SRiver Riddle /// Set all the calling functions and the callees of the function as not
1042e2c0738SRiver Riddle /// normalizable.
setCalleesAndCallersNonNormalizable(func::FuncOp funcOp,ModuleOp moduleOp,DenseSet<func::FuncOp> & normalizableFuncs)1052e2c0738SRiver Riddle void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
106*58ceae95SRiver Riddle     func::FuncOp funcOp, ModuleOp moduleOp,
107*58ceae95SRiver Riddle     DenseSet<func::FuncOp> &normalizableFuncs) {
1082e2c0738SRiver Riddle   if (!normalizableFuncs.contains(funcOp))
1092e2c0738SRiver Riddle     return;
1102e2c0738SRiver Riddle 
1112e2c0738SRiver Riddle   LLVM_DEBUG(
1122e2c0738SRiver Riddle       llvm::dbgs() << "@" << funcOp.getName()
1132e2c0738SRiver Riddle                    << " calls or is called by non-normalizable function\n");
1142e2c0738SRiver Riddle   normalizableFuncs.erase(funcOp);
1152e2c0738SRiver Riddle   // Caller of the function.
1162e2c0738SRiver Riddle   Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
1172e2c0738SRiver Riddle   for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
1182e2c0738SRiver Riddle     // TODO: Extend this for ops that are FunctionOpInterface. This would
1192e2c0738SRiver Riddle     // require creating an OpInterface for FunctionOpInterface ops.
120*58ceae95SRiver Riddle     func::FuncOp parentFuncOp =
121*58ceae95SRiver Riddle         symbolUse.getUser()->getParentOfType<func::FuncOp>();
122*58ceae95SRiver Riddle     for (func::FuncOp &funcOp : normalizableFuncs) {
1232e2c0738SRiver Riddle       if (parentFuncOp == funcOp) {
1242e2c0738SRiver Riddle         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
1252e2c0738SRiver Riddle                                             normalizableFuncs);
1262e2c0738SRiver Riddle         break;
1272e2c0738SRiver Riddle       }
1282e2c0738SRiver Riddle     }
1292e2c0738SRiver Riddle   }
1302e2c0738SRiver Riddle 
1312e2c0738SRiver Riddle   // Functions called by this function.
13223aa5a74SRiver Riddle   funcOp.walk([&](func::CallOp callOp) {
1332e2c0738SRiver Riddle     StringAttr callee = callOp.getCalleeAttr().getAttr();
134*58ceae95SRiver Riddle     for (func::FuncOp &funcOp : normalizableFuncs) {
135*58ceae95SRiver Riddle       // We compare func::FuncOp and callee's name.
1362e2c0738SRiver Riddle       if (callee == funcOp.getNameAttr()) {
1372e2c0738SRiver Riddle         setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
1382e2c0738SRiver Riddle                                             normalizableFuncs);
1392e2c0738SRiver Riddle         break;
1402e2c0738SRiver Riddle       }
1412e2c0738SRiver Riddle     }
1422e2c0738SRiver Riddle   });
1432e2c0738SRiver Riddle }
1442e2c0738SRiver Riddle 
1452e2c0738SRiver Riddle /// Check whether all the uses of AllocOps, CallOps and function arguments of a
1462e2c0738SRiver Riddle /// function are either of dereferencing type or are uses in: DeallocOp, CallOp
1472e2c0738SRiver Riddle /// or ReturnOp. Only if these constraints are satisfied will the function
1482e2c0738SRiver Riddle /// become a candidate for normalization. We follow a conservative approach here
1492e2c0738SRiver Riddle /// wherein even if the non-normalizable memref is not a part of the function's
1502e2c0738SRiver Riddle /// argument or return type, we still label the entire function as
1512e2c0738SRiver Riddle /// non-normalizable. We assume external functions to be normalizable.
areMemRefsNormalizable(func::FuncOp funcOp)152*58ceae95SRiver Riddle bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
1532e2c0738SRiver Riddle   // We assume external functions to be normalizable.
1542e2c0738SRiver Riddle   if (funcOp.isExternal())
1552e2c0738SRiver Riddle     return true;
1562e2c0738SRiver Riddle 
1572e2c0738SRiver Riddle   if (funcOp
1582e2c0738SRiver Riddle           .walk([&](memref::AllocOp allocOp) -> WalkResult {
1592e2c0738SRiver Riddle             Value oldMemRef = allocOp.getResult();
1602e2c0738SRiver Riddle             if (!isMemRefNormalizable(oldMemRef.getUsers()))
1612e2c0738SRiver Riddle               return WalkResult::interrupt();
1622e2c0738SRiver Riddle             return WalkResult::advance();
1632e2c0738SRiver Riddle           })
1642e2c0738SRiver Riddle           .wasInterrupted())
1652e2c0738SRiver Riddle     return false;
1662e2c0738SRiver Riddle 
1672e2c0738SRiver Riddle   if (funcOp
16823aa5a74SRiver Riddle           .walk([&](func::CallOp callOp) -> WalkResult {
1692e2c0738SRiver Riddle             for (unsigned resIndex :
1702e2c0738SRiver Riddle                  llvm::seq<unsigned>(0, callOp.getNumResults())) {
1712e2c0738SRiver Riddle               Value oldMemRef = callOp.getResult(resIndex);
1722e2c0738SRiver Riddle               if (oldMemRef.getType().isa<MemRefType>())
1732e2c0738SRiver Riddle                 if (!isMemRefNormalizable(oldMemRef.getUsers()))
1742e2c0738SRiver Riddle                   return WalkResult::interrupt();
1752e2c0738SRiver Riddle             }
1762e2c0738SRiver Riddle             return WalkResult::advance();
1772e2c0738SRiver Riddle           })
1782e2c0738SRiver Riddle           .wasInterrupted())
1792e2c0738SRiver Riddle     return false;
1802e2c0738SRiver Riddle 
1812e2c0738SRiver Riddle   for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
1822e2c0738SRiver Riddle     BlockArgument oldMemRef = funcOp.getArgument(argIndex);
1832e2c0738SRiver Riddle     if (oldMemRef.getType().isa<MemRefType>())
1842e2c0738SRiver Riddle       if (!isMemRefNormalizable(oldMemRef.getUsers()))
1852e2c0738SRiver Riddle         return false;
1862e2c0738SRiver Riddle   }
1872e2c0738SRiver Riddle 
1882e2c0738SRiver Riddle   return true;
1892e2c0738SRiver Riddle }
1902e2c0738SRiver Riddle 
1912e2c0738SRiver Riddle /// Fetch the updated argument list and result of the function and update the
1922e2c0738SRiver Riddle /// function signature. This updates the function's return type at the caller
1932e2c0738SRiver Riddle /// site and in case the return type is a normalized memref then it updates
1942e2c0738SRiver Riddle /// the calling function's signature.
1952e2c0738SRiver Riddle /// TODO: An update to the calling function signature is required only if the
1962e2c0738SRiver Riddle /// returned value is in turn used in ReturnOp of the calling function.
updateFunctionSignature(func::FuncOp funcOp,ModuleOp moduleOp)197*58ceae95SRiver Riddle void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
1982e2c0738SRiver Riddle                                                ModuleOp moduleOp) {
1994a3460a7SRiver Riddle   FunctionType functionType = funcOp.getFunctionType();
2002e2c0738SRiver Riddle   SmallVector<Type, 4> resultTypes;
2012e2c0738SRiver Riddle   FunctionType newFuncType;
2022e2c0738SRiver Riddle   resultTypes = llvm::to_vector<4>(functionType.getResults());
2032e2c0738SRiver Riddle 
2042e2c0738SRiver Riddle   // External function's signature was already updated in
2052e2c0738SRiver Riddle   // 'normalizeFuncOpMemRefs()'.
2062e2c0738SRiver Riddle   if (!funcOp.isExternal()) {
2072e2c0738SRiver Riddle     SmallVector<Type, 8> argTypes;
2082e2c0738SRiver Riddle     for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
2092e2c0738SRiver Riddle       argTypes.push_back(argEn.value().getType());
2102e2c0738SRiver Riddle 
2112e2c0738SRiver Riddle     // Traverse ReturnOps to check if an update to the return type in the
2122e2c0738SRiver Riddle     // function signature is required.
21323aa5a74SRiver Riddle     funcOp.walk([&](func::ReturnOp returnOp) {
2142e2c0738SRiver Riddle       for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
2152e2c0738SRiver Riddle         Type opType = operandEn.value().getType();
2162e2c0738SRiver Riddle         MemRefType memrefType = opType.dyn_cast<MemRefType>();
2172e2c0738SRiver Riddle         // If type is not memref or if the memref type is same as that in
2182e2c0738SRiver Riddle         // function's return signature then no update is required.
2192e2c0738SRiver Riddle         if (!memrefType || memrefType == resultTypes[operandEn.index()])
2202e2c0738SRiver Riddle           continue;
2212e2c0738SRiver Riddle         // Update function's return type signature.
2222e2c0738SRiver Riddle         // Return type gets normalized either as a result of function argument
2232e2c0738SRiver Riddle         // normalization, AllocOp normalization or an update made at CallOp.
2242e2c0738SRiver Riddle         // There can be many call flows inside a function and an update to a
2252e2c0738SRiver Riddle         // specific ReturnOp has not yet been made. So we check that the result
2262e2c0738SRiver Riddle         // memref type is normalized.
2272e2c0738SRiver Riddle         // TODO: When selective normalization is implemented, handle multiple
2282e2c0738SRiver Riddle         // results case where some are normalized, some aren't.
2292e2c0738SRiver Riddle         if (memrefType.getLayout().isIdentity())
2302e2c0738SRiver Riddle           resultTypes[operandEn.index()] = memrefType;
2312e2c0738SRiver Riddle       }
2322e2c0738SRiver Riddle     });
2332e2c0738SRiver Riddle 
2342e2c0738SRiver Riddle     // We create a new function type and modify the function signature with this
2352e2c0738SRiver Riddle     // new type.
2362e2c0738SRiver Riddle     newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
2372e2c0738SRiver Riddle                                     /*results=*/resultTypes);
2382e2c0738SRiver Riddle   }
2392e2c0738SRiver Riddle 
2402e2c0738SRiver Riddle   // Since we update the function signature, it might affect the result types at
2412e2c0738SRiver Riddle   // the caller site. Since this result might even be used by the caller
2422e2c0738SRiver Riddle   // function in ReturnOps, the caller function's signature will also change.
2432e2c0738SRiver Riddle   // Hence we record the caller function in 'funcOpsToUpdate' to update their
2442e2c0738SRiver Riddle   // signature as well.
245*58ceae95SRiver Riddle   llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
2462e2c0738SRiver Riddle   // We iterate over all symbolic uses of the function and update the return
2472e2c0738SRiver Riddle   // type at the caller site.
2482e2c0738SRiver Riddle   Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
2492e2c0738SRiver Riddle   for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
2502e2c0738SRiver Riddle     Operation *userOp = symbolUse.getUser();
2512e2c0738SRiver Riddle     OpBuilder builder(userOp);
2522e2c0738SRiver Riddle     // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
2532e2c0738SRiver Riddle     // that the non-CallOp has no memrefs to be replaced.
2542e2c0738SRiver Riddle     // TODO: Handle cases where a non-CallOp symbol use of a function deals with
2552e2c0738SRiver Riddle     // memrefs.
25623aa5a74SRiver Riddle     auto callOp = dyn_cast<func::CallOp>(userOp);
2572e2c0738SRiver Riddle     if (!callOp)
2582e2c0738SRiver Riddle       continue;
2592e2c0738SRiver Riddle     Operation *newCallOp =
26023aa5a74SRiver Riddle         builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
2612e2c0738SRiver Riddle                                      resultTypes, userOp->getOperands());
2622e2c0738SRiver Riddle     bool replacingMemRefUsesFailed = false;
2632e2c0738SRiver Riddle     bool returnTypeChanged = false;
2642e2c0738SRiver Riddle     for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
2652e2c0738SRiver Riddle       OpResult oldResult = userOp->getResult(resIndex);
2662e2c0738SRiver Riddle       OpResult newResult = newCallOp->getResult(resIndex);
2672e2c0738SRiver Riddle       // This condition ensures that if the result is not of type memref or if
2682e2c0738SRiver Riddle       // the resulting memref was already having a trivial map layout then we
2692e2c0738SRiver Riddle       // need not perform any use replacement here.
2702e2c0738SRiver Riddle       if (oldResult.getType() == newResult.getType())
2712e2c0738SRiver Riddle         continue;
2722e2c0738SRiver Riddle       AffineMap layoutMap =
2732e2c0738SRiver Riddle           oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
2742e2c0738SRiver Riddle       if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
2752e2c0738SRiver Riddle                                           /*extraIndices=*/{},
2762e2c0738SRiver Riddle                                           /*indexRemap=*/layoutMap,
2772e2c0738SRiver Riddle                                           /*extraOperands=*/{},
2782e2c0738SRiver Riddle                                           /*symbolOperands=*/{},
2792e2c0738SRiver Riddle                                           /*domOpFilter=*/nullptr,
2802e2c0738SRiver Riddle                                           /*postDomOpFilter=*/nullptr,
2812e2c0738SRiver Riddle                                           /*allowNonDereferencingOps=*/true,
2822e2c0738SRiver Riddle                                           /*replaceInDeallocOp=*/true))) {
2832e2c0738SRiver Riddle         // If it failed (due to escapes for example), bail out.
2842e2c0738SRiver Riddle         // It should never hit this part of the code because it is called by
2852e2c0738SRiver Riddle         // only those functions which are normalizable.
2862e2c0738SRiver Riddle         newCallOp->erase();
2872e2c0738SRiver Riddle         replacingMemRefUsesFailed = true;
2882e2c0738SRiver Riddle         break;
2892e2c0738SRiver Riddle       }
2902e2c0738SRiver Riddle       returnTypeChanged = true;
2912e2c0738SRiver Riddle     }
2922e2c0738SRiver Riddle     if (replacingMemRefUsesFailed)
2932e2c0738SRiver Riddle       continue;
2942e2c0738SRiver Riddle     // Replace all uses for other non-memref result types.
2952e2c0738SRiver Riddle     userOp->replaceAllUsesWith(newCallOp);
2962e2c0738SRiver Riddle     userOp->erase();
2972e2c0738SRiver Riddle     if (returnTypeChanged) {
2982e2c0738SRiver Riddle       // Since the return type changed it might lead to a change in function's
2992e2c0738SRiver Riddle       // signature.
3002e2c0738SRiver Riddle       // TODO: If funcOp doesn't return any memref type then no need to update
3012e2c0738SRiver Riddle       // signature.
3022e2c0738SRiver Riddle       // TODO: Further optimization - Check if the memref is indeed part of
3032e2c0738SRiver Riddle       // ReturnOp at the parentFuncOp and only then updation of signature is
3042e2c0738SRiver Riddle       // required.
3052e2c0738SRiver Riddle       // TODO: Extend this for ops that are FunctionOpInterface. This would
3062e2c0738SRiver Riddle       // require creating an OpInterface for FunctionOpInterface ops.
307*58ceae95SRiver Riddle       func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>();
3082e2c0738SRiver Riddle       funcOpsToUpdate.insert(parentFuncOp);
3092e2c0738SRiver Riddle     }
3102e2c0738SRiver Riddle   }
3112e2c0738SRiver Riddle   // Because external function's signature is already updated in
3122e2c0738SRiver Riddle   // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
3132e2c0738SRiver Riddle   if (!funcOp.isExternal())
3142e2c0738SRiver Riddle     funcOp.setType(newFuncType);
3152e2c0738SRiver Riddle 
3162e2c0738SRiver Riddle   // Updating the signature type of those functions which call the current
3172e2c0738SRiver Riddle   // function. Only if the return type of the current function has a normalized
3182e2c0738SRiver Riddle   // memref will the caller function become a candidate for signature update.
319*58ceae95SRiver Riddle   for (func::FuncOp parentFuncOp : funcOpsToUpdate)
3202e2c0738SRiver Riddle     updateFunctionSignature(parentFuncOp, moduleOp);
3212e2c0738SRiver Riddle }
3222e2c0738SRiver Riddle 
3232e2c0738SRiver Riddle /// Normalizes the memrefs within a function which includes those arising as a
3242e2c0738SRiver Riddle /// result of AllocOps, CallOps and function's argument. The ModuleOp argument
3252e2c0738SRiver Riddle /// is used to help update function's signature after normalization.
normalizeFuncOpMemRefs(func::FuncOp funcOp,ModuleOp moduleOp)326*58ceae95SRiver Riddle void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
3272e2c0738SRiver Riddle                                               ModuleOp moduleOp) {
3282e2c0738SRiver Riddle   // Turn memrefs' non-identity layouts maps into ones with identity. Collect
3292e2c0738SRiver Riddle   // alloc ops first and then process since normalizeMemRef replaces/erases ops
3302e2c0738SRiver Riddle   // during memref rewriting.
3312e2c0738SRiver Riddle   SmallVector<memref::AllocOp, 4> allocOps;
3322e2c0738SRiver Riddle   funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
3332e2c0738SRiver Riddle   for (memref::AllocOp allocOp : allocOps)
3342e2c0738SRiver Riddle     (void)normalizeMemRef(&allocOp);
3352e2c0738SRiver Riddle 
3362e2c0738SRiver Riddle   // We use this OpBuilder to create new memref layout later.
3372e2c0738SRiver Riddle   OpBuilder b(funcOp);
3382e2c0738SRiver Riddle 
3394a3460a7SRiver Riddle   FunctionType functionType = funcOp.getFunctionType();
3402e2c0738SRiver Riddle   SmallVector<Location> functionArgLocs(llvm::map_range(
3412e2c0738SRiver Riddle       funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
3422e2c0738SRiver Riddle   SmallVector<Type, 8> inputTypes;
3432e2c0738SRiver Riddle   // Walk over each argument of a function to perform memref normalization (if
3442e2c0738SRiver Riddle   for (unsigned argIndex :
3452e2c0738SRiver Riddle        llvm::seq<unsigned>(0, functionType.getNumInputs())) {
3462e2c0738SRiver Riddle     Type argType = functionType.getInput(argIndex);
3472e2c0738SRiver Riddle     MemRefType memrefType = argType.dyn_cast<MemRefType>();
3482e2c0738SRiver Riddle     // Check whether argument is of MemRef type. Any other argument type can
3492e2c0738SRiver Riddle     // simply be part of the final function signature.
3502e2c0738SRiver Riddle     if (!memrefType) {
3512e2c0738SRiver Riddle       inputTypes.push_back(argType);
3522e2c0738SRiver Riddle       continue;
3532e2c0738SRiver Riddle     }
3542e2c0738SRiver Riddle     // Fetch a new memref type after normalizing the old memref to have an
3552e2c0738SRiver Riddle     // identity map layout.
3562e2c0738SRiver Riddle     MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
3572e2c0738SRiver Riddle                                                    /*numSymbolicOperands=*/0);
3582e2c0738SRiver Riddle     if (newMemRefType == memrefType || funcOp.isExternal()) {
3592e2c0738SRiver Riddle       // Either memrefType already had an identity map or the map couldn't be
3602e2c0738SRiver Riddle       // transformed to an identity map.
3612e2c0738SRiver Riddle       inputTypes.push_back(newMemRefType);
3622e2c0738SRiver Riddle       continue;
3632e2c0738SRiver Riddle     }
3642e2c0738SRiver Riddle 
3652e2c0738SRiver Riddle     // Insert a new temporary argument with the new memref type.
3662e2c0738SRiver Riddle     BlockArgument newMemRef = funcOp.front().insertArgument(
3672e2c0738SRiver Riddle         argIndex, newMemRefType, functionArgLocs[argIndex]);
3682e2c0738SRiver Riddle     BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
3692e2c0738SRiver Riddle     AffineMap layoutMap = memrefType.getLayout().getAffineMap();
3702e2c0738SRiver Riddle     // Replace all uses of the old memref.
3712e2c0738SRiver Riddle     if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
3722e2c0738SRiver Riddle                                         /*extraIndices=*/{},
3732e2c0738SRiver Riddle                                         /*indexRemap=*/layoutMap,
3742e2c0738SRiver Riddle                                         /*extraOperands=*/{},
3752e2c0738SRiver Riddle                                         /*symbolOperands=*/{},
3762e2c0738SRiver Riddle                                         /*domOpFilter=*/nullptr,
3772e2c0738SRiver Riddle                                         /*postDomOpFilter=*/nullptr,
3782e2c0738SRiver Riddle                                         /*allowNonDereferencingOps=*/true,
3792e2c0738SRiver Riddle                                         /*replaceInDeallocOp=*/true))) {
3802e2c0738SRiver Riddle       // If it failed (due to escapes for example), bail out. Removing the
3812e2c0738SRiver Riddle       // temporary argument inserted previously.
3822e2c0738SRiver Riddle       funcOp.front().eraseArgument(argIndex);
3832e2c0738SRiver Riddle       continue;
3842e2c0738SRiver Riddle     }
3852e2c0738SRiver Riddle 
3862e2c0738SRiver Riddle     // All uses for the argument with old memref type were replaced
3872e2c0738SRiver Riddle     // successfully. So we remove the old argument now.
3882e2c0738SRiver Riddle     funcOp.front().eraseArgument(argIndex + 1);
3892e2c0738SRiver Riddle   }
3902e2c0738SRiver Riddle 
3912e2c0738SRiver Riddle   // Walk over normalizable operations to normalize memrefs of the operation
3922e2c0738SRiver Riddle   // results. When `op` has memrefs with affine map in the operation results,
3932e2c0738SRiver Riddle   // new operation containin normalized memrefs is created. Then, the memrefs
3942e2c0738SRiver Riddle   // are replaced. `CallOp` is skipped here because it is handled in
3952e2c0738SRiver Riddle   // `updateFunctionSignature()`.
3962e2c0738SRiver Riddle   funcOp.walk([&](Operation *op) {
3972e2c0738SRiver Riddle     if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
39823aa5a74SRiver Riddle         op->getNumResults() > 0 && !isa<func::CallOp>(op) &&
39923aa5a74SRiver Riddle         !funcOp.isExternal()) {
4002e2c0738SRiver Riddle       // Create newOp containing normalized memref in the operation result.
4012e2c0738SRiver Riddle       Operation *newOp = createOpResultsNormalized(funcOp, op);
4022e2c0738SRiver Riddle       // When all of the operation results have no memrefs or memrefs without
4032e2c0738SRiver Riddle       // affine map, `newOp` is the same with `op` and following process is
4042e2c0738SRiver Riddle       // skipped.
4052e2c0738SRiver Riddle       if (op != newOp) {
4062e2c0738SRiver Riddle         bool replacingMemRefUsesFailed = false;
4072e2c0738SRiver Riddle         for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
4082e2c0738SRiver Riddle           // Replace all uses of the old memrefs.
4092e2c0738SRiver Riddle           Value oldMemRef = op->getResult(resIndex);
4102e2c0738SRiver Riddle           Value newMemRef = newOp->getResult(resIndex);
4112e2c0738SRiver Riddle           MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
4122e2c0738SRiver Riddle           // Check whether the operation result is MemRef type.
4132e2c0738SRiver Riddle           if (!oldMemRefType)
4142e2c0738SRiver Riddle             continue;
4152e2c0738SRiver Riddle           MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
4162e2c0738SRiver Riddle           if (oldMemRefType == newMemRefType)
4172e2c0738SRiver Riddle             continue;
4182e2c0738SRiver Riddle           // TODO: Assume single layout map. Multiple maps not supported.
4192e2c0738SRiver Riddle           AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
4202e2c0738SRiver Riddle           if (failed(replaceAllMemRefUsesWith(oldMemRef,
4212e2c0738SRiver Riddle                                               /*newMemRef=*/newMemRef,
4222e2c0738SRiver Riddle                                               /*extraIndices=*/{},
4232e2c0738SRiver Riddle                                               /*indexRemap=*/layoutMap,
4242e2c0738SRiver Riddle                                               /*extraOperands=*/{},
4252e2c0738SRiver Riddle                                               /*symbolOperands=*/{},
4262e2c0738SRiver Riddle                                               /*domOpFilter=*/nullptr,
4272e2c0738SRiver Riddle                                               /*postDomOpFilter=*/nullptr,
4282e2c0738SRiver Riddle                                               /*allowNonDereferencingOps=*/true,
4292e2c0738SRiver Riddle                                               /*replaceInDeallocOp=*/true))) {
4302e2c0738SRiver Riddle             newOp->erase();
4312e2c0738SRiver Riddle             replacingMemRefUsesFailed = true;
4322e2c0738SRiver Riddle             continue;
4332e2c0738SRiver Riddle           }
4342e2c0738SRiver Riddle         }
4352e2c0738SRiver Riddle         if (!replacingMemRefUsesFailed) {
4362e2c0738SRiver Riddle           // Replace other ops with new op and delete the old op when the
4372e2c0738SRiver Riddle           // replacement succeeded.
4382e2c0738SRiver Riddle           op->replaceAllUsesWith(newOp);
4392e2c0738SRiver Riddle           op->erase();
4402e2c0738SRiver Riddle         }
4412e2c0738SRiver Riddle       }
4422e2c0738SRiver Riddle     }
4432e2c0738SRiver Riddle   });
4442e2c0738SRiver Riddle 
4452e2c0738SRiver Riddle   // In a normal function, memrefs in the return type signature gets normalized
4462e2c0738SRiver Riddle   // as a result of normalization of functions arguments, AllocOps or CallOps'
4472e2c0738SRiver Riddle   // result types. Since an external function doesn't have a body, memrefs in
4482e2c0738SRiver Riddle   // the return type signature can only get normalized by iterating over the
4492e2c0738SRiver Riddle   // individual return types.
4502e2c0738SRiver Riddle   if (funcOp.isExternal()) {
4512e2c0738SRiver Riddle     SmallVector<Type, 4> resultTypes;
4522e2c0738SRiver Riddle     for (unsigned resIndex :
4532e2c0738SRiver Riddle          llvm::seq<unsigned>(0, functionType.getNumResults())) {
4542e2c0738SRiver Riddle       Type resType = functionType.getResult(resIndex);
4552e2c0738SRiver Riddle       MemRefType memrefType = resType.dyn_cast<MemRefType>();
4562e2c0738SRiver Riddle       // Check whether result is of MemRef type. Any other argument type can
4572e2c0738SRiver Riddle       // simply be part of the final function signature.
4582e2c0738SRiver Riddle       if (!memrefType) {
4592e2c0738SRiver Riddle         resultTypes.push_back(resType);
4602e2c0738SRiver Riddle         continue;
4612e2c0738SRiver Riddle       }
4622e2c0738SRiver Riddle       // Computing a new memref type after normalizing the old memref to have an
4632e2c0738SRiver Riddle       // identity map layout.
4642e2c0738SRiver Riddle       MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
4652e2c0738SRiver Riddle                                                      /*numSymbolicOperands=*/0);
4662e2c0738SRiver Riddle       resultTypes.push_back(newMemRefType);
4672e2c0738SRiver Riddle     }
4682e2c0738SRiver Riddle 
4692e2c0738SRiver Riddle     FunctionType newFuncType =
4702e2c0738SRiver Riddle         FunctionType::get(&getContext(), /*inputs=*/inputTypes,
4712e2c0738SRiver Riddle                           /*results=*/resultTypes);
4722e2c0738SRiver Riddle     // Setting the new function signature for this external function.
4732e2c0738SRiver Riddle     funcOp.setType(newFuncType);
4742e2c0738SRiver Riddle   }
4752e2c0738SRiver Riddle   updateFunctionSignature(funcOp, moduleOp);
4762e2c0738SRiver Riddle }
4772e2c0738SRiver Riddle 
4782e2c0738SRiver Riddle /// Create an operation containing normalized memrefs in the operation results.
4792e2c0738SRiver Riddle /// When the results of `oldOp` have memrefs with affine map, the memrefs are
4802e2c0738SRiver Riddle /// normalized, and new operation containing them in the operation results is
4812e2c0738SRiver Riddle /// returned. If all of the results of `oldOp` have no memrefs or memrefs
4822e2c0738SRiver Riddle /// without affine map, `oldOp` is returned without modification.
createOpResultsNormalized(func::FuncOp funcOp,Operation * oldOp)483*58ceae95SRiver Riddle Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
4842e2c0738SRiver Riddle                                                        Operation *oldOp) {
4852e2c0738SRiver Riddle   // Prepare OperationState to create newOp containing normalized memref in
4862e2c0738SRiver Riddle   // the operation results.
4872e2c0738SRiver Riddle   OperationState result(oldOp->getLoc(), oldOp->getName());
4882e2c0738SRiver Riddle   result.addOperands(oldOp->getOperands());
4892e2c0738SRiver Riddle   result.addAttributes(oldOp->getAttrs());
4902e2c0738SRiver Riddle   // Add normalized MemRefType to the OperationState.
4912e2c0738SRiver Riddle   SmallVector<Type, 4> resultTypes;
4922e2c0738SRiver Riddle   OpBuilder b(funcOp);
4932e2c0738SRiver Riddle   bool resultTypeNormalized = false;
4942e2c0738SRiver Riddle   for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
4952e2c0738SRiver Riddle     auto resultType = oldOp->getResult(resIndex).getType();
4962e2c0738SRiver Riddle     MemRefType memrefType = resultType.dyn_cast<MemRefType>();
4972e2c0738SRiver Riddle     // Check whether the operation result is MemRef type.
4982e2c0738SRiver Riddle     if (!memrefType) {
4992e2c0738SRiver Riddle       resultTypes.push_back(resultType);
5002e2c0738SRiver Riddle       continue;
5012e2c0738SRiver Riddle     }
5022e2c0738SRiver Riddle     // Fetch a new memref type after normalizing the old memref.
5032e2c0738SRiver Riddle     MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
5042e2c0738SRiver Riddle                                                    /*numSymbolicOperands=*/0);
5052e2c0738SRiver Riddle     if (newMemRefType == memrefType) {
5062e2c0738SRiver Riddle       // Either memrefType already had an identity map or the map couldn't
5072e2c0738SRiver Riddle       // be transformed to an identity map.
5082e2c0738SRiver Riddle       resultTypes.push_back(memrefType);
5092e2c0738SRiver Riddle       continue;
5102e2c0738SRiver Riddle     }
5112e2c0738SRiver Riddle     resultTypes.push_back(newMemRefType);
5122e2c0738SRiver Riddle     resultTypeNormalized = true;
5132e2c0738SRiver Riddle   }
5142e2c0738SRiver Riddle   result.addTypes(resultTypes);
5152e2c0738SRiver Riddle   // When all of the results of `oldOp` have no memrefs or memrefs without
5162e2c0738SRiver Riddle   // affine map, `oldOp` is returned without modification.
5172e2c0738SRiver Riddle   if (resultTypeNormalized) {
5182e2c0738SRiver Riddle     OpBuilder bb(oldOp);
5192e2c0738SRiver Riddle     for (auto &oldRegion : oldOp->getRegions()) {
5202e2c0738SRiver Riddle       Region *newRegion = result.addRegion();
5212e2c0738SRiver Riddle       newRegion->takeBody(oldRegion);
5222e2c0738SRiver Riddle     }
52314ecafd0SChia-hung Duan     return bb.create(result);
5242e2c0738SRiver Riddle   }
5252e2c0738SRiver Riddle   return oldOp;
5262e2c0738SRiver Riddle }
527