12e2c0738SRiver Riddle //===- NormalizeMemRefs.cpp -----------------------------------------------===//
22e2c0738SRiver Riddle //
32e2c0738SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
42e2c0738SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
52e2c0738SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
62e2c0738SRiver Riddle //
72e2c0738SRiver Riddle //===----------------------------------------------------------------------===//
82e2c0738SRiver Riddle //
92e2c0738SRiver Riddle // This file implements an interprocedural pass to normalize memrefs to have
102e2c0738SRiver Riddle // identity layout maps.
112e2c0738SRiver Riddle //
122e2c0738SRiver Riddle //===----------------------------------------------------------------------===//
132e2c0738SRiver Riddle
142e2c0738SRiver Riddle #include "PassDetail.h"
152e2c0738SRiver Riddle #include "mlir/Dialect/Affine/IR/AffineOps.h"
16a70aa7bbSRiver Riddle #include "mlir/Dialect/Affine/Utils.h"
171f971e23SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
182e2c0738SRiver Riddle #include "mlir/Dialect/MemRef/IR/MemRef.h"
192e2c0738SRiver Riddle #include "mlir/Dialect/MemRef/Transforms/Passes.h"
202e2c0738SRiver Riddle #include "llvm/ADT/SmallSet.h"
212e2c0738SRiver Riddle #include "llvm/Support/Debug.h"
222e2c0738SRiver Riddle
232e2c0738SRiver Riddle #define DEBUG_TYPE "normalize-memrefs"
242e2c0738SRiver Riddle
252e2c0738SRiver Riddle using namespace mlir;
262e2c0738SRiver Riddle
272e2c0738SRiver Riddle namespace {
282e2c0738SRiver Riddle
292e2c0738SRiver Riddle /// All memrefs passed across functions with non-trivial layout maps are
302e2c0738SRiver Riddle /// converted to ones with trivial identity layout ones.
312e2c0738SRiver Riddle /// If all the memref types/uses in a function are normalizable, we treat
322e2c0738SRiver Riddle /// such functions as normalizable. Also, if a normalizable function is known
332e2c0738SRiver Riddle /// to call a non-normalizable function, we treat that function as
342e2c0738SRiver Riddle /// non-normalizable as well. We assume external functions to be normalizable.
352e2c0738SRiver Riddle struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
362e2c0738SRiver Riddle void runOnOperation() override;
37*58ceae95SRiver Riddle void normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp);
38*58ceae95SRiver Riddle bool areMemRefsNormalizable(func::FuncOp funcOp);
39*58ceae95SRiver Riddle void updateFunctionSignature(func::FuncOp funcOp, ModuleOp moduleOp);
40*58ceae95SRiver Riddle void setCalleesAndCallersNonNormalizable(
41*58ceae95SRiver Riddle func::FuncOp funcOp, ModuleOp moduleOp,
42*58ceae95SRiver Riddle DenseSet<func::FuncOp> &normalizableFuncs);
43*58ceae95SRiver Riddle Operation *createOpResultsNormalized(func::FuncOp funcOp, Operation *oldOp);
442e2c0738SRiver Riddle };
452e2c0738SRiver Riddle
462e2c0738SRiver Riddle } // namespace
472e2c0738SRiver Riddle
482e2c0738SRiver Riddle std::unique_ptr<OperationPass<ModuleOp>>
createNormalizeMemRefsPass()492e2c0738SRiver Riddle mlir::memref::createNormalizeMemRefsPass() {
502e2c0738SRiver Riddle return std::make_unique<NormalizeMemRefs>();
512e2c0738SRiver Riddle }
522e2c0738SRiver Riddle
runOnOperation()532e2c0738SRiver Riddle void NormalizeMemRefs::runOnOperation() {
542e2c0738SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
552e2c0738SRiver Riddle ModuleOp moduleOp = getOperation();
562e2c0738SRiver Riddle // We maintain all normalizable FuncOps in a DenseSet. It is initialized
572e2c0738SRiver Riddle // with all the functions within a module and then functions which are not
582e2c0738SRiver Riddle // normalizable are removed from this set.
592e2c0738SRiver Riddle // TODO: Change this to work on FuncLikeOp once there is an operation
602e2c0738SRiver Riddle // interface for it.
61*58ceae95SRiver Riddle DenseSet<func::FuncOp> normalizableFuncs;
622e2c0738SRiver Riddle // Initialize `normalizableFuncs` with all the functions within a module.
63*58ceae95SRiver Riddle moduleOp.walk([&](func::FuncOp funcOp) { normalizableFuncs.insert(funcOp); });
642e2c0738SRiver Riddle
652e2c0738SRiver Riddle // Traverse through all the functions applying a filter which determines
662e2c0738SRiver Riddle // whether that function is normalizable or not. All callers/callees of
672e2c0738SRiver Riddle // a non-normalizable function will also become non-normalizable even if
682e2c0738SRiver Riddle // they aren't passing any or specific non-normalizable memrefs. So,
692e2c0738SRiver Riddle // functions which calls or get called by a non-normalizable becomes non-
702e2c0738SRiver Riddle // normalizable functions themselves.
71*58ceae95SRiver Riddle moduleOp.walk([&](func::FuncOp funcOp) {
722e2c0738SRiver Riddle if (normalizableFuncs.contains(funcOp)) {
732e2c0738SRiver Riddle if (!areMemRefsNormalizable(funcOp)) {
742e2c0738SRiver Riddle LLVM_DEBUG(llvm::dbgs()
752e2c0738SRiver Riddle << "@" << funcOp.getName()
762e2c0738SRiver Riddle << " contains ops that cannot normalize MemRefs\n");
772e2c0738SRiver Riddle // Since this function is not normalizable, we set all the caller
782e2c0738SRiver Riddle // functions and the callees of this function as not normalizable.
792e2c0738SRiver Riddle // TODO: Drop this conservative assumption in the future.
802e2c0738SRiver Riddle setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
812e2c0738SRiver Riddle normalizableFuncs);
822e2c0738SRiver Riddle }
832e2c0738SRiver Riddle }
842e2c0738SRiver Riddle });
852e2c0738SRiver Riddle
862e2c0738SRiver Riddle LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
872e2c0738SRiver Riddle << " functions\n");
882e2c0738SRiver Riddle // Those functions which can be normalized are subjected to normalization.
89*58ceae95SRiver Riddle for (func::FuncOp &funcOp : normalizableFuncs)
902e2c0738SRiver Riddle normalizeFuncOpMemRefs(funcOp, moduleOp);
912e2c0738SRiver Riddle }
922e2c0738SRiver Riddle
932e2c0738SRiver Riddle /// Check whether all the uses of oldMemRef are either dereferencing uses or the
942e2c0738SRiver Riddle /// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
952e2c0738SRiver Riddle /// are satisfied will the value become a candidate for replacement.
962e2c0738SRiver Riddle /// TODO: Extend this for DimOps.
isMemRefNormalizable(Value::user_range opUsers)972e2c0738SRiver Riddle static bool isMemRefNormalizable(Value::user_range opUsers) {
982e2c0738SRiver Riddle return llvm::all_of(opUsers, [](Operation *op) {
992e2c0738SRiver Riddle return op->hasTrait<OpTrait::MemRefsNormalizable>();
1002e2c0738SRiver Riddle });
1012e2c0738SRiver Riddle }
1022e2c0738SRiver Riddle
1032e2c0738SRiver Riddle /// Set all the calling functions and the callees of the function as not
1042e2c0738SRiver Riddle /// normalizable.
setCalleesAndCallersNonNormalizable(func::FuncOp funcOp,ModuleOp moduleOp,DenseSet<func::FuncOp> & normalizableFuncs)1052e2c0738SRiver Riddle void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
106*58ceae95SRiver Riddle func::FuncOp funcOp, ModuleOp moduleOp,
107*58ceae95SRiver Riddle DenseSet<func::FuncOp> &normalizableFuncs) {
1082e2c0738SRiver Riddle if (!normalizableFuncs.contains(funcOp))
1092e2c0738SRiver Riddle return;
1102e2c0738SRiver Riddle
1112e2c0738SRiver Riddle LLVM_DEBUG(
1122e2c0738SRiver Riddle llvm::dbgs() << "@" << funcOp.getName()
1132e2c0738SRiver Riddle << " calls or is called by non-normalizable function\n");
1142e2c0738SRiver Riddle normalizableFuncs.erase(funcOp);
1152e2c0738SRiver Riddle // Caller of the function.
1162e2c0738SRiver Riddle Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
1172e2c0738SRiver Riddle for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
1182e2c0738SRiver Riddle // TODO: Extend this for ops that are FunctionOpInterface. This would
1192e2c0738SRiver Riddle // require creating an OpInterface for FunctionOpInterface ops.
120*58ceae95SRiver Riddle func::FuncOp parentFuncOp =
121*58ceae95SRiver Riddle symbolUse.getUser()->getParentOfType<func::FuncOp>();
122*58ceae95SRiver Riddle for (func::FuncOp &funcOp : normalizableFuncs) {
1232e2c0738SRiver Riddle if (parentFuncOp == funcOp) {
1242e2c0738SRiver Riddle setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
1252e2c0738SRiver Riddle normalizableFuncs);
1262e2c0738SRiver Riddle break;
1272e2c0738SRiver Riddle }
1282e2c0738SRiver Riddle }
1292e2c0738SRiver Riddle }
1302e2c0738SRiver Riddle
1312e2c0738SRiver Riddle // Functions called by this function.
13223aa5a74SRiver Riddle funcOp.walk([&](func::CallOp callOp) {
1332e2c0738SRiver Riddle StringAttr callee = callOp.getCalleeAttr().getAttr();
134*58ceae95SRiver Riddle for (func::FuncOp &funcOp : normalizableFuncs) {
135*58ceae95SRiver Riddle // We compare func::FuncOp and callee's name.
1362e2c0738SRiver Riddle if (callee == funcOp.getNameAttr()) {
1372e2c0738SRiver Riddle setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
1382e2c0738SRiver Riddle normalizableFuncs);
1392e2c0738SRiver Riddle break;
1402e2c0738SRiver Riddle }
1412e2c0738SRiver Riddle }
1422e2c0738SRiver Riddle });
1432e2c0738SRiver Riddle }
1442e2c0738SRiver Riddle
1452e2c0738SRiver Riddle /// Check whether all the uses of AllocOps, CallOps and function arguments of a
1462e2c0738SRiver Riddle /// function are either of dereferencing type or are uses in: DeallocOp, CallOp
1472e2c0738SRiver Riddle /// or ReturnOp. Only if these constraints are satisfied will the function
1482e2c0738SRiver Riddle /// become a candidate for normalization. We follow a conservative approach here
1492e2c0738SRiver Riddle /// wherein even if the non-normalizable memref is not a part of the function's
1502e2c0738SRiver Riddle /// argument or return type, we still label the entire function as
1512e2c0738SRiver Riddle /// non-normalizable. We assume external functions to be normalizable.
areMemRefsNormalizable(func::FuncOp funcOp)152*58ceae95SRiver Riddle bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) {
1532e2c0738SRiver Riddle // We assume external functions to be normalizable.
1542e2c0738SRiver Riddle if (funcOp.isExternal())
1552e2c0738SRiver Riddle return true;
1562e2c0738SRiver Riddle
1572e2c0738SRiver Riddle if (funcOp
1582e2c0738SRiver Riddle .walk([&](memref::AllocOp allocOp) -> WalkResult {
1592e2c0738SRiver Riddle Value oldMemRef = allocOp.getResult();
1602e2c0738SRiver Riddle if (!isMemRefNormalizable(oldMemRef.getUsers()))
1612e2c0738SRiver Riddle return WalkResult::interrupt();
1622e2c0738SRiver Riddle return WalkResult::advance();
1632e2c0738SRiver Riddle })
1642e2c0738SRiver Riddle .wasInterrupted())
1652e2c0738SRiver Riddle return false;
1662e2c0738SRiver Riddle
1672e2c0738SRiver Riddle if (funcOp
16823aa5a74SRiver Riddle .walk([&](func::CallOp callOp) -> WalkResult {
1692e2c0738SRiver Riddle for (unsigned resIndex :
1702e2c0738SRiver Riddle llvm::seq<unsigned>(0, callOp.getNumResults())) {
1712e2c0738SRiver Riddle Value oldMemRef = callOp.getResult(resIndex);
1722e2c0738SRiver Riddle if (oldMemRef.getType().isa<MemRefType>())
1732e2c0738SRiver Riddle if (!isMemRefNormalizable(oldMemRef.getUsers()))
1742e2c0738SRiver Riddle return WalkResult::interrupt();
1752e2c0738SRiver Riddle }
1762e2c0738SRiver Riddle return WalkResult::advance();
1772e2c0738SRiver Riddle })
1782e2c0738SRiver Riddle .wasInterrupted())
1792e2c0738SRiver Riddle return false;
1802e2c0738SRiver Riddle
1812e2c0738SRiver Riddle for (unsigned argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
1822e2c0738SRiver Riddle BlockArgument oldMemRef = funcOp.getArgument(argIndex);
1832e2c0738SRiver Riddle if (oldMemRef.getType().isa<MemRefType>())
1842e2c0738SRiver Riddle if (!isMemRefNormalizable(oldMemRef.getUsers()))
1852e2c0738SRiver Riddle return false;
1862e2c0738SRiver Riddle }
1872e2c0738SRiver Riddle
1882e2c0738SRiver Riddle return true;
1892e2c0738SRiver Riddle }
1902e2c0738SRiver Riddle
1912e2c0738SRiver Riddle /// Fetch the updated argument list and result of the function and update the
1922e2c0738SRiver Riddle /// function signature. This updates the function's return type at the caller
1932e2c0738SRiver Riddle /// site and in case the return type is a normalized memref then it updates
1942e2c0738SRiver Riddle /// the calling function's signature.
1952e2c0738SRiver Riddle /// TODO: An update to the calling function signature is required only if the
1962e2c0738SRiver Riddle /// returned value is in turn used in ReturnOp of the calling function.
updateFunctionSignature(func::FuncOp funcOp,ModuleOp moduleOp)197*58ceae95SRiver Riddle void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
1982e2c0738SRiver Riddle ModuleOp moduleOp) {
1994a3460a7SRiver Riddle FunctionType functionType = funcOp.getFunctionType();
2002e2c0738SRiver Riddle SmallVector<Type, 4> resultTypes;
2012e2c0738SRiver Riddle FunctionType newFuncType;
2022e2c0738SRiver Riddle resultTypes = llvm::to_vector<4>(functionType.getResults());
2032e2c0738SRiver Riddle
2042e2c0738SRiver Riddle // External function's signature was already updated in
2052e2c0738SRiver Riddle // 'normalizeFuncOpMemRefs()'.
2062e2c0738SRiver Riddle if (!funcOp.isExternal()) {
2072e2c0738SRiver Riddle SmallVector<Type, 8> argTypes;
2082e2c0738SRiver Riddle for (const auto &argEn : llvm::enumerate(funcOp.getArguments()))
2092e2c0738SRiver Riddle argTypes.push_back(argEn.value().getType());
2102e2c0738SRiver Riddle
2112e2c0738SRiver Riddle // Traverse ReturnOps to check if an update to the return type in the
2122e2c0738SRiver Riddle // function signature is required.
21323aa5a74SRiver Riddle funcOp.walk([&](func::ReturnOp returnOp) {
2142e2c0738SRiver Riddle for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
2152e2c0738SRiver Riddle Type opType = operandEn.value().getType();
2162e2c0738SRiver Riddle MemRefType memrefType = opType.dyn_cast<MemRefType>();
2172e2c0738SRiver Riddle // If type is not memref or if the memref type is same as that in
2182e2c0738SRiver Riddle // function's return signature then no update is required.
2192e2c0738SRiver Riddle if (!memrefType || memrefType == resultTypes[operandEn.index()])
2202e2c0738SRiver Riddle continue;
2212e2c0738SRiver Riddle // Update function's return type signature.
2222e2c0738SRiver Riddle // Return type gets normalized either as a result of function argument
2232e2c0738SRiver Riddle // normalization, AllocOp normalization or an update made at CallOp.
2242e2c0738SRiver Riddle // There can be many call flows inside a function and an update to a
2252e2c0738SRiver Riddle // specific ReturnOp has not yet been made. So we check that the result
2262e2c0738SRiver Riddle // memref type is normalized.
2272e2c0738SRiver Riddle // TODO: When selective normalization is implemented, handle multiple
2282e2c0738SRiver Riddle // results case where some are normalized, some aren't.
2292e2c0738SRiver Riddle if (memrefType.getLayout().isIdentity())
2302e2c0738SRiver Riddle resultTypes[operandEn.index()] = memrefType;
2312e2c0738SRiver Riddle }
2322e2c0738SRiver Riddle });
2332e2c0738SRiver Riddle
2342e2c0738SRiver Riddle // We create a new function type and modify the function signature with this
2352e2c0738SRiver Riddle // new type.
2362e2c0738SRiver Riddle newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
2372e2c0738SRiver Riddle /*results=*/resultTypes);
2382e2c0738SRiver Riddle }
2392e2c0738SRiver Riddle
2402e2c0738SRiver Riddle // Since we update the function signature, it might affect the result types at
2412e2c0738SRiver Riddle // the caller site. Since this result might even be used by the caller
2422e2c0738SRiver Riddle // function in ReturnOps, the caller function's signature will also change.
2432e2c0738SRiver Riddle // Hence we record the caller function in 'funcOpsToUpdate' to update their
2442e2c0738SRiver Riddle // signature as well.
245*58ceae95SRiver Riddle llvm::SmallDenseSet<func::FuncOp, 8> funcOpsToUpdate;
2462e2c0738SRiver Riddle // We iterate over all symbolic uses of the function and update the return
2472e2c0738SRiver Riddle // type at the caller site.
2482e2c0738SRiver Riddle Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
2492e2c0738SRiver Riddle for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
2502e2c0738SRiver Riddle Operation *userOp = symbolUse.getUser();
2512e2c0738SRiver Riddle OpBuilder builder(userOp);
2522e2c0738SRiver Riddle // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
2532e2c0738SRiver Riddle // that the non-CallOp has no memrefs to be replaced.
2542e2c0738SRiver Riddle // TODO: Handle cases where a non-CallOp symbol use of a function deals with
2552e2c0738SRiver Riddle // memrefs.
25623aa5a74SRiver Riddle auto callOp = dyn_cast<func::CallOp>(userOp);
2572e2c0738SRiver Riddle if (!callOp)
2582e2c0738SRiver Riddle continue;
2592e2c0738SRiver Riddle Operation *newCallOp =
26023aa5a74SRiver Riddle builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
2612e2c0738SRiver Riddle resultTypes, userOp->getOperands());
2622e2c0738SRiver Riddle bool replacingMemRefUsesFailed = false;
2632e2c0738SRiver Riddle bool returnTypeChanged = false;
2642e2c0738SRiver Riddle for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
2652e2c0738SRiver Riddle OpResult oldResult = userOp->getResult(resIndex);
2662e2c0738SRiver Riddle OpResult newResult = newCallOp->getResult(resIndex);
2672e2c0738SRiver Riddle // This condition ensures that if the result is not of type memref or if
2682e2c0738SRiver Riddle // the resulting memref was already having a trivial map layout then we
2692e2c0738SRiver Riddle // need not perform any use replacement here.
2702e2c0738SRiver Riddle if (oldResult.getType() == newResult.getType())
2712e2c0738SRiver Riddle continue;
2722e2c0738SRiver Riddle AffineMap layoutMap =
2732e2c0738SRiver Riddle oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
2742e2c0738SRiver Riddle if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
2752e2c0738SRiver Riddle /*extraIndices=*/{},
2762e2c0738SRiver Riddle /*indexRemap=*/layoutMap,
2772e2c0738SRiver Riddle /*extraOperands=*/{},
2782e2c0738SRiver Riddle /*symbolOperands=*/{},
2792e2c0738SRiver Riddle /*domOpFilter=*/nullptr,
2802e2c0738SRiver Riddle /*postDomOpFilter=*/nullptr,
2812e2c0738SRiver Riddle /*allowNonDereferencingOps=*/true,
2822e2c0738SRiver Riddle /*replaceInDeallocOp=*/true))) {
2832e2c0738SRiver Riddle // If it failed (due to escapes for example), bail out.
2842e2c0738SRiver Riddle // It should never hit this part of the code because it is called by
2852e2c0738SRiver Riddle // only those functions which are normalizable.
2862e2c0738SRiver Riddle newCallOp->erase();
2872e2c0738SRiver Riddle replacingMemRefUsesFailed = true;
2882e2c0738SRiver Riddle break;
2892e2c0738SRiver Riddle }
2902e2c0738SRiver Riddle returnTypeChanged = true;
2912e2c0738SRiver Riddle }
2922e2c0738SRiver Riddle if (replacingMemRefUsesFailed)
2932e2c0738SRiver Riddle continue;
2942e2c0738SRiver Riddle // Replace all uses for other non-memref result types.
2952e2c0738SRiver Riddle userOp->replaceAllUsesWith(newCallOp);
2962e2c0738SRiver Riddle userOp->erase();
2972e2c0738SRiver Riddle if (returnTypeChanged) {
2982e2c0738SRiver Riddle // Since the return type changed it might lead to a change in function's
2992e2c0738SRiver Riddle // signature.
3002e2c0738SRiver Riddle // TODO: If funcOp doesn't return any memref type then no need to update
3012e2c0738SRiver Riddle // signature.
3022e2c0738SRiver Riddle // TODO: Further optimization - Check if the memref is indeed part of
3032e2c0738SRiver Riddle // ReturnOp at the parentFuncOp and only then updation of signature is
3042e2c0738SRiver Riddle // required.
3052e2c0738SRiver Riddle // TODO: Extend this for ops that are FunctionOpInterface. This would
3062e2c0738SRiver Riddle // require creating an OpInterface for FunctionOpInterface ops.
307*58ceae95SRiver Riddle func::FuncOp parentFuncOp = newCallOp->getParentOfType<func::FuncOp>();
3082e2c0738SRiver Riddle funcOpsToUpdate.insert(parentFuncOp);
3092e2c0738SRiver Riddle }
3102e2c0738SRiver Riddle }
3112e2c0738SRiver Riddle // Because external function's signature is already updated in
3122e2c0738SRiver Riddle // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
3132e2c0738SRiver Riddle if (!funcOp.isExternal())
3142e2c0738SRiver Riddle funcOp.setType(newFuncType);
3152e2c0738SRiver Riddle
3162e2c0738SRiver Riddle // Updating the signature type of those functions which call the current
3172e2c0738SRiver Riddle // function. Only if the return type of the current function has a normalized
3182e2c0738SRiver Riddle // memref will the caller function become a candidate for signature update.
319*58ceae95SRiver Riddle for (func::FuncOp parentFuncOp : funcOpsToUpdate)
3202e2c0738SRiver Riddle updateFunctionSignature(parentFuncOp, moduleOp);
3212e2c0738SRiver Riddle }
3222e2c0738SRiver Riddle
3232e2c0738SRiver Riddle /// Normalizes the memrefs within a function which includes those arising as a
3242e2c0738SRiver Riddle /// result of AllocOps, CallOps and function's argument. The ModuleOp argument
3252e2c0738SRiver Riddle /// is used to help update function's signature after normalization.
normalizeFuncOpMemRefs(func::FuncOp funcOp,ModuleOp moduleOp)326*58ceae95SRiver Riddle void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
3272e2c0738SRiver Riddle ModuleOp moduleOp) {
3282e2c0738SRiver Riddle // Turn memrefs' non-identity layouts maps into ones with identity. Collect
3292e2c0738SRiver Riddle // alloc ops first and then process since normalizeMemRef replaces/erases ops
3302e2c0738SRiver Riddle // during memref rewriting.
3312e2c0738SRiver Riddle SmallVector<memref::AllocOp, 4> allocOps;
3322e2c0738SRiver Riddle funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
3332e2c0738SRiver Riddle for (memref::AllocOp allocOp : allocOps)
3342e2c0738SRiver Riddle (void)normalizeMemRef(&allocOp);
3352e2c0738SRiver Riddle
3362e2c0738SRiver Riddle // We use this OpBuilder to create new memref layout later.
3372e2c0738SRiver Riddle OpBuilder b(funcOp);
3382e2c0738SRiver Riddle
3394a3460a7SRiver Riddle FunctionType functionType = funcOp.getFunctionType();
3402e2c0738SRiver Riddle SmallVector<Location> functionArgLocs(llvm::map_range(
3412e2c0738SRiver Riddle funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
3422e2c0738SRiver Riddle SmallVector<Type, 8> inputTypes;
3432e2c0738SRiver Riddle // Walk over each argument of a function to perform memref normalization (if
3442e2c0738SRiver Riddle for (unsigned argIndex :
3452e2c0738SRiver Riddle llvm::seq<unsigned>(0, functionType.getNumInputs())) {
3462e2c0738SRiver Riddle Type argType = functionType.getInput(argIndex);
3472e2c0738SRiver Riddle MemRefType memrefType = argType.dyn_cast<MemRefType>();
3482e2c0738SRiver Riddle // Check whether argument is of MemRef type. Any other argument type can
3492e2c0738SRiver Riddle // simply be part of the final function signature.
3502e2c0738SRiver Riddle if (!memrefType) {
3512e2c0738SRiver Riddle inputTypes.push_back(argType);
3522e2c0738SRiver Riddle continue;
3532e2c0738SRiver Riddle }
3542e2c0738SRiver Riddle // Fetch a new memref type after normalizing the old memref to have an
3552e2c0738SRiver Riddle // identity map layout.
3562e2c0738SRiver Riddle MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
3572e2c0738SRiver Riddle /*numSymbolicOperands=*/0);
3582e2c0738SRiver Riddle if (newMemRefType == memrefType || funcOp.isExternal()) {
3592e2c0738SRiver Riddle // Either memrefType already had an identity map or the map couldn't be
3602e2c0738SRiver Riddle // transformed to an identity map.
3612e2c0738SRiver Riddle inputTypes.push_back(newMemRefType);
3622e2c0738SRiver Riddle continue;
3632e2c0738SRiver Riddle }
3642e2c0738SRiver Riddle
3652e2c0738SRiver Riddle // Insert a new temporary argument with the new memref type.
3662e2c0738SRiver Riddle BlockArgument newMemRef = funcOp.front().insertArgument(
3672e2c0738SRiver Riddle argIndex, newMemRefType, functionArgLocs[argIndex]);
3682e2c0738SRiver Riddle BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
3692e2c0738SRiver Riddle AffineMap layoutMap = memrefType.getLayout().getAffineMap();
3702e2c0738SRiver Riddle // Replace all uses of the old memref.
3712e2c0738SRiver Riddle if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
3722e2c0738SRiver Riddle /*extraIndices=*/{},
3732e2c0738SRiver Riddle /*indexRemap=*/layoutMap,
3742e2c0738SRiver Riddle /*extraOperands=*/{},
3752e2c0738SRiver Riddle /*symbolOperands=*/{},
3762e2c0738SRiver Riddle /*domOpFilter=*/nullptr,
3772e2c0738SRiver Riddle /*postDomOpFilter=*/nullptr,
3782e2c0738SRiver Riddle /*allowNonDereferencingOps=*/true,
3792e2c0738SRiver Riddle /*replaceInDeallocOp=*/true))) {
3802e2c0738SRiver Riddle // If it failed (due to escapes for example), bail out. Removing the
3812e2c0738SRiver Riddle // temporary argument inserted previously.
3822e2c0738SRiver Riddle funcOp.front().eraseArgument(argIndex);
3832e2c0738SRiver Riddle continue;
3842e2c0738SRiver Riddle }
3852e2c0738SRiver Riddle
3862e2c0738SRiver Riddle // All uses for the argument with old memref type were replaced
3872e2c0738SRiver Riddle // successfully. So we remove the old argument now.
3882e2c0738SRiver Riddle funcOp.front().eraseArgument(argIndex + 1);
3892e2c0738SRiver Riddle }
3902e2c0738SRiver Riddle
3912e2c0738SRiver Riddle // Walk over normalizable operations to normalize memrefs of the operation
3922e2c0738SRiver Riddle // results. When `op` has memrefs with affine map in the operation results,
3932e2c0738SRiver Riddle // new operation containin normalized memrefs is created. Then, the memrefs
3942e2c0738SRiver Riddle // are replaced. `CallOp` is skipped here because it is handled in
3952e2c0738SRiver Riddle // `updateFunctionSignature()`.
3962e2c0738SRiver Riddle funcOp.walk([&](Operation *op) {
3972e2c0738SRiver Riddle if (op->hasTrait<OpTrait::MemRefsNormalizable>() &&
39823aa5a74SRiver Riddle op->getNumResults() > 0 && !isa<func::CallOp>(op) &&
39923aa5a74SRiver Riddle !funcOp.isExternal()) {
4002e2c0738SRiver Riddle // Create newOp containing normalized memref in the operation result.
4012e2c0738SRiver Riddle Operation *newOp = createOpResultsNormalized(funcOp, op);
4022e2c0738SRiver Riddle // When all of the operation results have no memrefs or memrefs without
4032e2c0738SRiver Riddle // affine map, `newOp` is the same with `op` and following process is
4042e2c0738SRiver Riddle // skipped.
4052e2c0738SRiver Riddle if (op != newOp) {
4062e2c0738SRiver Riddle bool replacingMemRefUsesFailed = false;
4072e2c0738SRiver Riddle for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
4082e2c0738SRiver Riddle // Replace all uses of the old memrefs.
4092e2c0738SRiver Riddle Value oldMemRef = op->getResult(resIndex);
4102e2c0738SRiver Riddle Value newMemRef = newOp->getResult(resIndex);
4112e2c0738SRiver Riddle MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
4122e2c0738SRiver Riddle // Check whether the operation result is MemRef type.
4132e2c0738SRiver Riddle if (!oldMemRefType)
4142e2c0738SRiver Riddle continue;
4152e2c0738SRiver Riddle MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
4162e2c0738SRiver Riddle if (oldMemRefType == newMemRefType)
4172e2c0738SRiver Riddle continue;
4182e2c0738SRiver Riddle // TODO: Assume single layout map. Multiple maps not supported.
4192e2c0738SRiver Riddle AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
4202e2c0738SRiver Riddle if (failed(replaceAllMemRefUsesWith(oldMemRef,
4212e2c0738SRiver Riddle /*newMemRef=*/newMemRef,
4222e2c0738SRiver Riddle /*extraIndices=*/{},
4232e2c0738SRiver Riddle /*indexRemap=*/layoutMap,
4242e2c0738SRiver Riddle /*extraOperands=*/{},
4252e2c0738SRiver Riddle /*symbolOperands=*/{},
4262e2c0738SRiver Riddle /*domOpFilter=*/nullptr,
4272e2c0738SRiver Riddle /*postDomOpFilter=*/nullptr,
4282e2c0738SRiver Riddle /*allowNonDereferencingOps=*/true,
4292e2c0738SRiver Riddle /*replaceInDeallocOp=*/true))) {
4302e2c0738SRiver Riddle newOp->erase();
4312e2c0738SRiver Riddle replacingMemRefUsesFailed = true;
4322e2c0738SRiver Riddle continue;
4332e2c0738SRiver Riddle }
4342e2c0738SRiver Riddle }
4352e2c0738SRiver Riddle if (!replacingMemRefUsesFailed) {
4362e2c0738SRiver Riddle // Replace other ops with new op and delete the old op when the
4372e2c0738SRiver Riddle // replacement succeeded.
4382e2c0738SRiver Riddle op->replaceAllUsesWith(newOp);
4392e2c0738SRiver Riddle op->erase();
4402e2c0738SRiver Riddle }
4412e2c0738SRiver Riddle }
4422e2c0738SRiver Riddle }
4432e2c0738SRiver Riddle });
4442e2c0738SRiver Riddle
4452e2c0738SRiver Riddle // In a normal function, memrefs in the return type signature gets normalized
4462e2c0738SRiver Riddle // as a result of normalization of functions arguments, AllocOps or CallOps'
4472e2c0738SRiver Riddle // result types. Since an external function doesn't have a body, memrefs in
4482e2c0738SRiver Riddle // the return type signature can only get normalized by iterating over the
4492e2c0738SRiver Riddle // individual return types.
4502e2c0738SRiver Riddle if (funcOp.isExternal()) {
4512e2c0738SRiver Riddle SmallVector<Type, 4> resultTypes;
4522e2c0738SRiver Riddle for (unsigned resIndex :
4532e2c0738SRiver Riddle llvm::seq<unsigned>(0, functionType.getNumResults())) {
4542e2c0738SRiver Riddle Type resType = functionType.getResult(resIndex);
4552e2c0738SRiver Riddle MemRefType memrefType = resType.dyn_cast<MemRefType>();
4562e2c0738SRiver Riddle // Check whether result is of MemRef type. Any other argument type can
4572e2c0738SRiver Riddle // simply be part of the final function signature.
4582e2c0738SRiver Riddle if (!memrefType) {
4592e2c0738SRiver Riddle resultTypes.push_back(resType);
4602e2c0738SRiver Riddle continue;
4612e2c0738SRiver Riddle }
4622e2c0738SRiver Riddle // Computing a new memref type after normalizing the old memref to have an
4632e2c0738SRiver Riddle // identity map layout.
4642e2c0738SRiver Riddle MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
4652e2c0738SRiver Riddle /*numSymbolicOperands=*/0);
4662e2c0738SRiver Riddle resultTypes.push_back(newMemRefType);
4672e2c0738SRiver Riddle }
4682e2c0738SRiver Riddle
4692e2c0738SRiver Riddle FunctionType newFuncType =
4702e2c0738SRiver Riddle FunctionType::get(&getContext(), /*inputs=*/inputTypes,
4712e2c0738SRiver Riddle /*results=*/resultTypes);
4722e2c0738SRiver Riddle // Setting the new function signature for this external function.
4732e2c0738SRiver Riddle funcOp.setType(newFuncType);
4742e2c0738SRiver Riddle }
4752e2c0738SRiver Riddle updateFunctionSignature(funcOp, moduleOp);
4762e2c0738SRiver Riddle }
4772e2c0738SRiver Riddle
4782e2c0738SRiver Riddle /// Create an operation containing normalized memrefs in the operation results.
4792e2c0738SRiver Riddle /// When the results of `oldOp` have memrefs with affine map, the memrefs are
4802e2c0738SRiver Riddle /// normalized, and new operation containing them in the operation results is
4812e2c0738SRiver Riddle /// returned. If all of the results of `oldOp` have no memrefs or memrefs
4822e2c0738SRiver Riddle /// without affine map, `oldOp` is returned without modification.
createOpResultsNormalized(func::FuncOp funcOp,Operation * oldOp)483*58ceae95SRiver Riddle Operation *NormalizeMemRefs::createOpResultsNormalized(func::FuncOp funcOp,
4842e2c0738SRiver Riddle Operation *oldOp) {
4852e2c0738SRiver Riddle // Prepare OperationState to create newOp containing normalized memref in
4862e2c0738SRiver Riddle // the operation results.
4872e2c0738SRiver Riddle OperationState result(oldOp->getLoc(), oldOp->getName());
4882e2c0738SRiver Riddle result.addOperands(oldOp->getOperands());
4892e2c0738SRiver Riddle result.addAttributes(oldOp->getAttrs());
4902e2c0738SRiver Riddle // Add normalized MemRefType to the OperationState.
4912e2c0738SRiver Riddle SmallVector<Type, 4> resultTypes;
4922e2c0738SRiver Riddle OpBuilder b(funcOp);
4932e2c0738SRiver Riddle bool resultTypeNormalized = false;
4942e2c0738SRiver Riddle for (unsigned resIndex : llvm::seq<unsigned>(0, oldOp->getNumResults())) {
4952e2c0738SRiver Riddle auto resultType = oldOp->getResult(resIndex).getType();
4962e2c0738SRiver Riddle MemRefType memrefType = resultType.dyn_cast<MemRefType>();
4972e2c0738SRiver Riddle // Check whether the operation result is MemRef type.
4982e2c0738SRiver Riddle if (!memrefType) {
4992e2c0738SRiver Riddle resultTypes.push_back(resultType);
5002e2c0738SRiver Riddle continue;
5012e2c0738SRiver Riddle }
5022e2c0738SRiver Riddle // Fetch a new memref type after normalizing the old memref.
5032e2c0738SRiver Riddle MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
5042e2c0738SRiver Riddle /*numSymbolicOperands=*/0);
5052e2c0738SRiver Riddle if (newMemRefType == memrefType) {
5062e2c0738SRiver Riddle // Either memrefType already had an identity map or the map couldn't
5072e2c0738SRiver Riddle // be transformed to an identity map.
5082e2c0738SRiver Riddle resultTypes.push_back(memrefType);
5092e2c0738SRiver Riddle continue;
5102e2c0738SRiver Riddle }
5112e2c0738SRiver Riddle resultTypes.push_back(newMemRefType);
5122e2c0738SRiver Riddle resultTypeNormalized = true;
5132e2c0738SRiver Riddle }
5142e2c0738SRiver Riddle result.addTypes(resultTypes);
5152e2c0738SRiver Riddle // When all of the results of `oldOp` have no memrefs or memrefs without
5162e2c0738SRiver Riddle // affine map, `oldOp` is returned without modification.
5172e2c0738SRiver Riddle if (resultTypeNormalized) {
5182e2c0738SRiver Riddle OpBuilder bb(oldOp);
5192e2c0738SRiver Riddle for (auto &oldRegion : oldOp->getRegions()) {
5202e2c0738SRiver Riddle Region *newRegion = result.addRegion();
5212e2c0738SRiver Riddle newRegion->takeBody(oldRegion);
5222e2c0738SRiver Riddle }
52314ecafd0SChia-hung Duan return bb.create(result);
5242e2c0738SRiver Riddle }
5252e2c0738SRiver Riddle return oldOp;
5262e2c0738SRiver Riddle }
527