//===- NormalizeMemRefs.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an interprocedural pass to normalize memrefs to have
// identity layout maps.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "normalize-memrefs"

using namespace mlir;

namespace {

/// Module-level pass that rewrites memrefs with non-trivial layout maps into
/// memrefs with identity layouts, including those passed across function
/// boundaries.
/// If all the memref types/uses in a function are normalizable, we treat
/// such functions as normalizable. Also, if a normalizable function is known
/// to call a non-normalizable function, we treat that function as
/// non-normalizable as well. We assume external functions to be normalizable.
struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
  /// Pass entry point: collects every FuncOp in the module, filters out the
  /// non-normalizable ones and normalizes the remainder.
  void runOnOperation() override;
  /// Normalizes the memrefs produced by AllocOps, function arguments and
  /// normalizable op results inside `funcOp`, then updates its signature.
  void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);
  /// Returns true if `funcOp` is external, or if every user of its AllocOp
  /// results, memref-typed CallOp results and memref-typed arguments is
  /// normalizable.
  bool areMemRefsNormalizable(FuncOp funcOp);
  /// Rebuilds the type of `funcOp` after normalization and recreates the
  /// CallOps that reference it with the updated result types.
  void updateFunctionSignature(FuncOp funcOp, ModuleOp moduleOp);
  /// Removes `funcOp`, and transitively its callers and callees, from
  /// `normalizableFuncs`.
  void setCalleesAndCallersNonNormalizable(FuncOp funcOp, ModuleOp moduleOp,
                                           DenseSet<FuncOp> &normalizableFuncs);
  /// Returns a clone-like replacement of `oldOp` whose memref result types are
  /// normalized, or `oldOp` itself when no result type changes.
  Operation *createOpResultsNormalized(FuncOp funcOp, Operation *oldOp);
};

} // namespace

/// Factory for the memref normalization pass; the returned pass operates on a
/// ModuleOp.
std::unique_ptr<OperationPass<ModuleOp>>
mlir::memref::createNormalizeMemRefsPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<NormalizeMemRefs>();
  return pass;
}

void NormalizeMemRefs::runOnOperation() {
  LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
  ModuleOp moduleOp = getOperation();
  // We maintain all normalizable FuncOps in a DenseSet. It is initialized
  // with all the functions within a module and then functions which are not
  // normalizable are removed from this set.
  // TODO: Change this to work on FuncLikeOp once there is an operation
  // interface for it.
  DenseSet<FuncOp> normalizableFuncs;
  // Initialize `normalizableFuncs` with all the functions within a module.
  moduleOp.walk([&](FuncOp funcOp) { normalizableFuncs.insert(funcOp); });

  // Traverse through all the functions applying a filter which determines
  // whether that function is normalizable or not. All callers/callees of
  // a non-normalizable function will also become non-normalizable even if
  // they aren't passing any or specific non-normalizable memrefs. So,
  // functions which calls or get called by a non-normalizable becomes non-
  // normalizable functions themselves.
  moduleOp.walk([&](FuncOp funcOp) {
    if (normalizableFuncs.contains(funcOp)) {
      if (!areMemRefsNormalizable(funcOp)) {
        LLVM_DEBUG(llvm::dbgs()
                   << "@" << funcOp.getName()
                   << " contains ops that cannot normalize MemRefs\n");
        // Since this function is not normalizable, we set all the caller
        // functions and the callees of this function as not normalizable.
        // TODO: Drop this conservative assumption in the future.
        setCalleesAndCallersNonNormalizable(funcOp, moduleOp,
                                            normalizableFuncs);
      }
    }
  });

  LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
                          << " functions\n");
  // Those functions which can be normalized are subjected to normalization.
  for (FuncOp &funcOp : normalizableFuncs)
    normalizeFuncOpMemRefs(funcOp, moduleOp);
}

/// Returns true when every operation in `opUsers` carries the
/// `MemRefsNormalizable` trait, i.e. when the value whose users these are can
/// be considered for replacement by a normalized memref. An empty user range
/// yields true.
/// TODO: Extend this for DimOps.
static bool isMemRefNormalizable(Value::user_range opUsers) {
  for (Operation *user : opUsers) {
    if (!user->hasTrait<OpTrait::MemRefsNormalizable>())
      return false;
  }
  return true;
}

/// Mark `funcOp` as non-normalizable and transitively do the same for every
/// function that calls it and every function it calls.
///
/// `normalizableFuncs` is the working set of functions still considered
/// normalizable; `funcOp` and all functions reached through the call graph are
/// erased from it. The function is a no-op when `funcOp` is not in the set,
/// which is also what terminates the recursion on cyclic call graphs.
///
/// `moduleOp` is the scope in which the symbol uses of `funcOp` are looked up.
/// If those uses cannot be computed, the callers of `funcOp` are unknown, so
/// the whole set is conservatively cleared.
void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
    FuncOp funcOp, ModuleOp moduleOp, DenseSet<FuncOp> &normalizableFuncs) {
  if (!normalizableFuncs.contains(funcOp))
    return;

  LLVM_DEBUG(
      llvm::dbgs() << "@" << funcOp.getName()
                   << " calls or is called by non-normalizable function\n");
  normalizableFuncs.erase(funcOp);

  // Callers of the function.
  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
  if (!symbolUses) {
    // The uses could not be enumerated, so any function may be a caller of
    // this one. Nothing can be proven safe to normalize in that situation.
    LLVM_DEBUG(llvm::dbgs()
               << "unable to compute symbol uses of @" << funcOp.getName()
               << "; treating all functions as non-normalizable\n");
    normalizableFuncs.clear();
    return;
  }
  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
    // TODO: Extend this for ops that are FunctionOpInterface. This would
    // require creating an OpInterface for FunctionOpInterface ops.
    FuncOp callerFuncOp = symbolUse.getUser()->getParentOfType<FuncOp>();
    // A use that is not nested inside a FuncOp has no caller to disqualify.
    // The recursive call itself ignores functions that are no longer in the
    // set, so no separate membership scan is needed here.
    if (callerFuncOp)
      setCalleesAndCallersNonNormalizable(callerFuncOp, moduleOp,
                                          normalizableFuncs);
  }

  // Functions called by this function.
  funcOp.walk([&](func::CallOp callOp) {
    StringAttr callee = callOp.getCalleeAttr().getAttr();
    // The callee is identified by name among the remaining candidates. The
    // candidate is copied out of the set because the recursive call erases it,
    // and the scan must stop right after for the same reason.
    for (FuncOp candidate : normalizableFuncs) {
      if (callee != candidate.getNameAttr())
        continue;
      setCalleesAndCallersNonNormalizable(candidate, moduleOp,
                                          normalizableFuncs);
      break;
    }
  });
}

/// Decides whether `funcOp` is a candidate for normalization.
///
/// External functions are always accepted. Otherwise three groups of values
/// are inspected, and each of them must only have normalizable users: the
/// results of all AllocOps, the memref-typed results of all CallOps, and the
/// memref-typed function arguments. The check is conservative: a single
/// offending memref disqualifies the entire function, even when that memref is
/// not part of the function's arguments or results.
bool NormalizeMemRefs::areMemRefsNormalizable(FuncOp funcOp) {
  // We assume external functions to be normalizable.
  if (funcOp.isExternal())
    return true;

  // Results of allocations.
  WalkResult allocWalk = funcOp.walk([](memref::AllocOp allocOp) {
    Value allocated = allocOp.getResult();
    return isMemRefNormalizable(allocated.getUsers())
               ? WalkResult::advance()
               : WalkResult::interrupt();
  });
  if (allocWalk.wasInterrupted())
    return false;

  // Memref-typed results of calls; results of other types are ignored.
  WalkResult callWalk = funcOp.walk([](func::CallOp callOp) {
    bool allResultsOk = llvm::all_of(callOp->getResults(), [](OpResult res) {
      return !res.getType().isa<MemRefType>() ||
             isMemRefNormalizable(res.getUsers());
    });
    return allResultsOk ? WalkResult::advance() : WalkResult::interrupt();
  });
  if (callWalk.wasInterrupted())
    return false;

  // Memref-typed function arguments; arguments of other types are ignored.
  return llvm::all_of(funcOp.getArguments(), [](BlockArgument arg) {
    return !arg.getType().isa<MemRefType>() ||
           isMemRefNormalizable(arg.getUsers());
  });
}

/// Recompute the signature of `funcOp` after its memrefs were normalized and
/// propagate any change in its result types to the call sites.
///
/// For a function with a body, the argument types are taken from the entry
/// block arguments and the result types from the operands of its ReturnOps. An
/// external function keeps the type it currently has, since that type was
/// already rewritten in 'normalizeFuncOpMemRefs()'.
///
/// Every CallOp that references `funcOp` (looked up within `moduleOp`) is
/// recreated with the new result types, and uses of results whose type changed
/// are remapped through the old layout map. The enclosing function of each
/// such call is then updated recursively, because the changed value may be
/// returned from it.
///
/// If the symbol uses of `funcOp` cannot be computed, an error is emitted on
/// `funcOp`, the pass is marked as failed and nothing else is changed.
///
/// TODO: An update to the calling function signature is required only if the
/// returned value is in turn used in ReturnOp of the calling function.
void NormalizeMemRefs::updateFunctionSignature(FuncOp funcOp,
                                               ModuleOp moduleOp) {
  FunctionType functionType = funcOp.getType();
  SmallVector<Type, 4> resultTypes =
      llvm::to_vector<4>(functionType.getResults());
  FunctionType newFuncType;

  // External function's signature was already updated in
  // 'normalizeFuncOpMemRefs()'.
  if (!funcOp.isExternal()) {
    // The argument types are read from the block arguments, not from the
    // function type: normalization replaces the block arguments and leaves the
    // function type to be updated here.
    SmallVector<Type, 8> argTypes;
    for (BlockArgument arg : funcOp.getArguments())
      argTypes.push_back(arg.getType());

    // Traverse ReturnOps to check if an update to the return type in the
    // function signature is required.
    funcOp.walk([&](func::ReturnOp returnOp) {
      for (const auto &operandEn : llvm::enumerate(returnOp.getOperands())) {
        Type opType = operandEn.value().getType();
        MemRefType memrefType = opType.dyn_cast<MemRefType>();
        // If type is not memref or if the memref type is same as that in
        // function's return signature then no update is required.
        if (!memrefType || memrefType == resultTypes[operandEn.index()])
          continue;
        // Update function's return type signature.
        // Return type gets normalized either as a result of function argument
        // normalization, AllocOp normalization or an update made at CallOp.
        // There can be many call flows inside a function and an update to a
        // specific ReturnOp has not yet been made. So we check that the result
        // memref type is normalized.
        // TODO: When selective normalization is implemented, handle multiple
        // results case where some are normalized, some aren't.
        if (memrefType.getLayout().isIdentity())
          resultTypes[operandEn.index()] = memrefType;
      }
    });

    // We create a new function type and modify the function signature with this
    // new type.
    newFuncType = FunctionType::get(&getContext(), /*inputs=*/argTypes,
                                    /*results=*/resultTypes);
  }

  // We iterate over all symbolic uses of the function and update the return
  // type at the caller site. Without the list of uses the call sites cannot be
  // kept in sync with the new signature, so report the problem instead of
  // dereferencing an empty Optional.
  Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
  if (!symbolUses) {
    funcOp.emitError("unable to compute the symbol uses of this function; "
                     "cannot update its call sites after memref "
                     "normalization");
    signalPassFailure();
    return;
  }

  // Since we update the function signature, it might affect the result types at
  // the caller site. Since this result might even be used by the caller
  // function in ReturnOps, the caller function's signature will also change.
  // Hence we record the caller function in 'funcOpsToUpdate' to update their
  // signature as well.
  llvm::SmallDenseSet<FuncOp, 8> funcOpsToUpdate;
  for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
    Operation *userOp = symbolUse.getUser();
    // When `userOp` can not be casted to `CallOp`, it is skipped. This assumes
    // that the non-CallOp has no memrefs to be replaced.
    // TODO: Handle cases where a non-CallOp symbol use of a function deals with
    // memrefs.
    auto callOp = dyn_cast<func::CallOp>(userOp);
    if (!callOp)
      continue;
    // The replacement call is inserted right before the call it replaces.
    OpBuilder builder(userOp);
    Operation *newCallOp =
        builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(),
                                     resultTypes, userOp->getOperands());
    bool replacingMemRefUsesFailed = false;
    bool returnTypeChanged = false;
    for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) {
      OpResult oldResult = userOp->getResult(resIndex);
      OpResult newResult = newCallOp->getResult(resIndex);
      // This condition ensures that if the result is not of type memref or if
      // the resulting memref was already having a trivial map layout then we
      // need not perform any use replacement here.
      if (oldResult.getType() == newResult.getType())
        continue;
      AffineMap layoutMap =
          oldResult.getType().cast<MemRefType>().getLayout().getAffineMap();
      if (failed(replaceAllMemRefUsesWith(oldResult, /*newMemRef=*/newResult,
                                          /*extraIndices=*/{},
                                          /*indexRemap=*/layoutMap,
                                          /*extraOperands=*/{},
                                          /*symbolOperands=*/{},
                                          /*domOpFilter=*/nullptr,
                                          /*postDomOpFilter=*/nullptr,
                                          /*allowNonDereferencingOps=*/true,
                                          /*replaceInDeallocOp=*/true))) {
        // If it failed (due to escapes for example), bail out.
        // It should never hit this part of the code because it is called by
        // only those functions which are normalizable.
        newCallOp->erase();
        replacingMemRefUsesFailed = true;
        break;
      }
      returnTypeChanged = true;
    }
    if (replacingMemRefUsesFailed)
      continue;
    // Replace all uses for other non-memref result types.
    userOp->replaceAllUsesWith(newCallOp);
    userOp->erase();
    if (returnTypeChanged) {
      // Since the return type changed it might lead to a change in function's
      // signature.
      // TODO: If funcOp doesn't return any memref type then no need to update
      // signature.
      // TODO: Further optimization - Check if the memref is indeed part of
      // ReturnOp at the parentFuncOp and only then updation of signature is
      // required.
      // TODO: Extend this for ops that are FunctionOpInterface. This would
      // require creating an OpInterface for FunctionOpInterface ops.
      // A call that is not nested inside a FuncOp has no enclosing signature
      // to update; recording a null FuncOp would crash the recursion below.
      if (FuncOp parentFuncOp = newCallOp->getParentOfType<FuncOp>())
        funcOpsToUpdate.insert(parentFuncOp);
    }
  }
  // Because external function's signature is already updated in
  // 'normalizeFuncOpMemRefs()', we don't need to update it here again.
  if (!funcOp.isExternal())
    funcOp.setType(newFuncType);

  // Updating the signature type of those functions which call the current
  // function. Only if the return type of the current function has a normalized
  // memref will the caller function become a candidate for signature update.
  for (FuncOp parentFuncOp : funcOpsToUpdate)
    updateFunctionSignature(parentFuncOp, moduleOp);
}

/// Normalizes the memrefs within a function which includes those arising as a
/// result of AllocOps, CallOps and function's argument. The ModuleOp argument
/// is used to help update function's signature after normalization.
///
/// The work is done in this order:
///   1. every AllocOp in `funcOp` is normalized;
///   2. every memref-typed argument with a non-identity layout is replaced by
///      a new block argument of the normalized type (functions with a body
///      only);
///   3. every other op with the `MemRefsNormalizable` trait that produces
///      results is recreated with normalized result types (CallOps are left
///      to `updateFunctionSignature()`);
///   4. for an external function, the function type is rebuilt directly from
///      the normalized argument and result types;
///   5. `updateFunctionSignature()` is called to fix up the signature and the
///      call sites.
void NormalizeMemRefs::normalizeFuncOpMemRefs(FuncOp funcOp,
                                              ModuleOp moduleOp) {
  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
  // alloc ops first and then process since normalizeMemRef replaces/erases ops
  // during memref rewriting.
  SmallVector<memref::AllocOp, 4> allocOps;
  funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); });
  for (memref::AllocOp allocOp : allocOps)
    (void)normalizeMemRef(&allocOp);

  // We use this OpBuilder to create new memref layout later.
  OpBuilder b(funcOp);

  FunctionType functionType = funcOp.getType();
  SmallVector<Location> functionArgLocs(llvm::map_range(
      funcOp.getArguments(), [](BlockArgument arg) { return arg.getLoc(); }));
  // Normalized input types. These are only consumed when rebuilding the type
  // of an external function further below.
  SmallVector<Type, 8> inputTypes;
  // Walk over each argument of a function to perform memref normalization (if
  // any).
  for (unsigned argIndex :
       llvm::seq<unsigned>(0, functionType.getNumInputs())) {
    Type argType = functionType.getInput(argIndex);
    MemRefType memrefType = argType.dyn_cast<MemRefType>();
    // Check whether argument is of MemRef type. Any other argument type can
    // simply be part of the final function signature.
    if (!memrefType) {
      inputTypes.push_back(argType);
      continue;
    }
    // Fetch a new memref type after normalizing the old memref to have an
    // identity map layout.
    MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                   /*numSymbolicOperands=*/0);
    if (newMemRefType == memrefType || funcOp.isExternal()) {
      // Either memrefType already had an identity map or the map couldn't be
      // transformed to an identity map. External functions also stop here:
      // they have no entry block whose arguments could be replaced.
      inputTypes.push_back(newMemRefType);
      continue;
    }

    // Insert a new temporary argument with the new memref type. It is placed
    // at `argIndex`, which shifts the old argument to `argIndex + 1`.
    BlockArgument newMemRef = funcOp.front().insertArgument(
        argIndex, newMemRefType, functionArgLocs[argIndex]);
    BlockArgument oldMemRef = funcOp.getArgument(argIndex + 1);
    AffineMap layoutMap = memrefType.getLayout().getAffineMap();
    // Replace all uses of the old memref.
    if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newMemRef,
                                        /*extraIndices=*/{},
                                        /*indexRemap=*/layoutMap,
                                        /*extraOperands=*/{},
                                        /*symbolOperands=*/{},
                                        /*domOpFilter=*/nullptr,
                                        /*postDomOpFilter=*/nullptr,
                                        /*allowNonDereferencingOps=*/true,
                                        /*replaceInDeallocOp=*/true))) {
      // If it failed (due to escapes for example), bail out. Removing the
      // temporary argument inserted previously.
      funcOp.front().eraseArgument(argIndex);
      continue;
    }

    // All uses for the argument with old memref type were replaced
    // successfully. So we remove the old argument now.
    funcOp.front().eraseArgument(argIndex + 1);
  }

  // Walk over normalizable operations to normalize memrefs of the operation
  // results. When `op` has memrefs with affine map in the operation results,
  // new operation containing normalized memrefs is created. Then, the memrefs
  // are replaced. `CallOp` is skipped here because it is handled in
  // `updateFunctionSignature()`. External functions are skipped entirely.
  if (!funcOp.isExternal()) {
    funcOp.walk([&](Operation *op) {
      if (!op->hasTrait<OpTrait::MemRefsNormalizable>() ||
          op->getNumResults() == 0 || isa<func::CallOp>(op))
        return;
      // Create newOp containing normalized memref in the operation result.
      Operation *newOp = createOpResultsNormalized(funcOp, op);
      // When all of the operation results have no memrefs or memrefs without
      // affine map, `newOp` is the same with `op` and following process is
      // skipped.
      if (op == newOp)
        return;

      bool replacingMemRefUsesFailed = false;
      for (unsigned resIndex : llvm::seq<unsigned>(0, op->getNumResults())) {
        // Replace all uses of the old memrefs.
        Value oldMemRef = op->getResult(resIndex);
        Value newMemRef = newOp->getResult(resIndex);
        MemRefType oldMemRefType = oldMemRef.getType().dyn_cast<MemRefType>();
        // Check whether the operation result is MemRef type.
        if (!oldMemRefType)
          continue;
        MemRefType newMemRefType = newMemRef.getType().cast<MemRefType>();
        if (oldMemRefType == newMemRefType)
          continue;
        // TODO: Assume single layout map. Multiple maps not supported.
        AffineMap layoutMap = oldMemRefType.getLayout().getAffineMap();
        if (failed(replaceAllMemRefUsesWith(oldMemRef,
                                            /*newMemRef=*/newMemRef,
                                            /*extraIndices=*/{},
                                            /*indexRemap=*/layoutMap,
                                            /*extraOperands=*/{},
                                            /*symbolOperands=*/{},
                                            /*domOpFilter=*/nullptr,
                                            /*postDomOpFilter=*/nullptr,
                                            /*allowNonDereferencingOps=*/true,
                                            /*replaceInDeallocOp=*/true))) {
          // `createOpResultsNormalized()` moved the regions of `op` into
          // `newOp`; hand them back so that `op`, which stays in the IR, does
          // not lose its body when `newOp` is discarded.
          for (unsigned regionIndex = 0, numRegions = op->getNumRegions();
               regionIndex < numRegions; ++regionIndex)
            op->getRegion(regionIndex).takeBody(newOp->getRegion(regionIndex));
          newOp->erase();
          replacingMemRefUsesFailed = true;
          // Stop here: `newOp` no longer exists, so its remaining results must
          // not be looked at.
          break;
        }
      }
      if (!replacingMemRefUsesFailed) {
        // Replace other ops with new op and delete the old op when the
        // replacement succeeded.
        op->replaceAllUsesWith(newOp);
        op->erase();
      }
    });
  }

  // In a normal function, memrefs in the return type signature gets normalized
  // as a result of normalization of functions arguments, AllocOps or CallOps'
  // result types. Since an external function doesn't have a body, memrefs in
  // the return type signature can only get normalized by iterating over the
  // individual return types.
  if (funcOp.isExternal()) {
    SmallVector<Type, 4> resultTypes;
    for (unsigned resIndex :
         llvm::seq<unsigned>(0, functionType.getNumResults())) {
      Type resType = functionType.getResult(resIndex);
      MemRefType memrefType = resType.dyn_cast<MemRefType>();
      // Check whether result is of MemRef type. Any other argument type can
      // simply be part of the final function signature.
      if (!memrefType) {
        resultTypes.push_back(resType);
        continue;
      }
      // Computing a new memref type after normalizing the old memref to have an
      // identity map layout.
      MemRefType newMemRefType = normalizeMemRefType(memrefType, b,
                                                     /*numSymbolicOperands=*/0);
      resultTypes.push_back(newMemRefType);
    }

    FunctionType newFuncType =
        FunctionType::get(&getContext(), /*inputs=*/inputTypes,
                          /*results=*/resultTypes);
    // Setting the new function signature for this external function.
    funcOp.setType(newFuncType);
  }
  updateFunctionSignature(funcOp, moduleOp);
}

/// Builds a replacement for `oldOp` whose memref results use normalized types.
///
/// Each memref-typed result of `oldOp` is run through `normalizeMemRefType()`.
/// If at least one result type changes, a new operation with the same name,
/// location, operands and attributes is created right before `oldOp`; it gets
/// the updated result types and takes over the bodies of `oldOp`'s regions.
/// If no result type changes, `oldOp` itself is returned untouched, which is
/// how callers detect that nothing had to be done.
Operation *NormalizeMemRefs::createOpResultsNormalized(FuncOp funcOp,
                                                       Operation *oldOp) {
  // Work out the result types first; non-memref types are carried over as is.
  OpBuilder typeBuilder(funcOp);
  SmallVector<Type, 4> normalizedTypes;
  bool anyTypeChanged = false;
  for (OpResult oldResult : oldOp->getResults()) {
    Type type = oldResult.getType();
    if (auto memrefType = type.dyn_cast<MemRefType>()) {
      // The type comes back unchanged when it already has an identity map or
      // when the map couldn't be transformed to an identity map.
      MemRefType normalized = normalizeMemRefType(memrefType, typeBuilder,
                                                  /*numSymbolicOperands=*/0);
      if (normalized != memrefType) {
        anyTypeChanged = true;
        type = normalized;
      }
    }
    normalizedTypes.push_back(type);
  }

  // Nothing to normalize: hand back the original operation.
  if (!anyTypeChanged)
    return oldOp;

  // Describe the new operation: same identity as `oldOp`, new result types.
  OperationState state(oldOp->getLoc(), oldOp->getName());
  state.addOperands(oldOp->getOperands());
  state.addAttributes(oldOp->getAttrs());
  state.addTypes(normalizedTypes);
  // Move, rather than copy, the region bodies over to the new operation.
  for (Region &oldRegion : oldOp->getRegions())
    state.addRegion()->takeBody(oldRegion);

  // Insert the new operation immediately before the one it replaces.
  OpBuilder opBuilder(oldOp);
  return opBuilder.createOperation(state);
}
