//===-- AMDGPUReplaceLDSUseWithPointer.cpp --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all the uses of LDS within non-kernel functions by
// corresponding pointer counterparts.
//
// The main motivation behind this pass is to *prevent* the subsequent LDS
// lowering pass from directly packing LDS (assume large LDS) into a struct
// type, which would otherwise cause a huge amount of memory to be allocated for
// the struct instance within every kernel.
//
// A brief sketch of the algorithm implemented in this pass is as follows:
//
//   1. Collect all the LDS defined in the module which qualify for pointer
//      replacement, say it is, LDSGlobals set.
//
//   2. Collect all the reachable callees for each kernel defined in the module,
//      say it is, KernelToCallees map.
//
//   3. FOR (each global GV from LDSGlobals set) DO
//        LDSUsedNonKernels = Collect all non-kernel functions which use GV.
//        FOR (each kernel K in KernelToCallees map) DO
//           ReachableCallees = KernelToCallees[K]
//           ReachableAndLDSUsedCallees =
//              SetIntersect(LDSUsedNonKernels, ReachableCallees)
//           IF (ReachableAndLDSUsedCallees is not empty) THEN
//             Pointer = Create a pointer to point-to GV if not created.
//             Initialize Pointer to point-to GV within kernel K.
//           ENDIF
//        ENDFOR
//        Replace all uses of GV within non kernel functions by Pointer.
//      ENDFOR
//
// LLVM IR example:
//
//    Input IR:
//
//    @lds = internal addrspace(3) global [4 x i32] undef, align 16
//
//    define internal void @f0() {
//    entry:
//      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* @lds,
//             i32 0, i32 0
//      ret void
//    }
//
//    define protected amdgpu_kernel void @k0() {
//    entry:
//      call void @f0()
//      ret void
//    }
//
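//    Output IR (the lane-0 guard which the pass wraps around the LDS pointer
//    initialization within each kernel is omitted here for brevity):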
//
//    @lds = internal addrspace(3) global [4 x i32] undef, align 16
//    @lds.ptr = internal unnamed_addr addrspace(3) global i16 undef, align 2
//
//    define internal void @f0() {
//    entry:
//      %0 = load i16, i16 addrspace(3)* @lds.ptr, align 2
//      %1 = getelementptr i8, i8 addrspace(3)* null, i16 %0
//      %2 = bitcast i8 addrspace(3)* %1 to [4 x i32] addrspace(3)*
//      %gep = getelementptr inbounds [4 x i32], [4 x i32] addrspace(3)* %2,
//             i32 0, i32 0
//      ret void
//    }
//
//    define protected amdgpu_kernel void @k0() {
//    entry:
//      store i16 ptrtoint ([4 x i32] addrspace(3)* @lds to i16),
//            i16 addrspace(3)* @lds.ptr, align 2
//      call void @f0()
//      ret void
//    }
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "Utils/AMDGPULDSUtils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <algorithm>
#include <vector>

#define DEBUG_TYPE "amdgpu-replace-lds-use-with-pointer"

using namespace llvm;

namespace {

class ReplaceLDSUseImpl {
  Module &M;
  LLVMContext &Ctx;
  const DataLayout &DL;
  Constant *LDSMemBaseAddr;

  DenseMap<GlobalVariable *, GlobalVariable *> LDSToPointer;
  DenseMap<GlobalVariable *, SmallPtrSet<Function *, 8>> LDSToNonKernels;
  DenseMap<Function *, SmallPtrSet<Function *, 8>> KernelToCallees;
  DenseMap<Function *, SmallPtrSet<GlobalVariable *, 8>> KernelToLDSPointers;
  DenseMap<Function *, BasicBlock *> KernelToInitBB;
  DenseMap<Function *, DenseMap<GlobalVariable *, Value *>>
      FunctionToLDSToReplaceInst;

  // Collect LDS which requires their uses to be replaced by pointer.
  std::vector<GlobalVariable *> collectLDSRequiringPointerReplace() {
    // Collect LDS which requires module lowering.
    std::vector<GlobalVariable *> LDSGlobals = AMDGPU::findVariablesToLower(M);

    // Remove LDS which don't qualify for replacement.
    LDSGlobals.erase(std::remove_if(LDSGlobals.begin(), LDSGlobals.end(),
                                    [&](GlobalVariable *GV) {
                                      return shouldIgnorePointerReplacement(GV);
                                    }),
                     LDSGlobals.end());

    return LDSGlobals;
  }

  // Returns true if uses of given LDS global within non-kernel functions should
  // be keep as it is without pointer replacement.
  bool shouldIgnorePointerReplacement(GlobalVariable *GV) {
    // LDS whose size is very small and doesn`t exceed pointer size is not worth
    // replacing.
    if (DL.getTypeAllocSize(GV->getValueType()) <= 2)
      return true;

    // LDS which is not used from non-kernel function scope or it is used from
    // global scope does not qualify for replacement.
    LDSToNonKernels[GV] = AMDGPU::collectNonKernelAccessorsOfLDS(GV);
    return LDSToNonKernels[GV].empty();

    // FIXME: When GV is used within all (or within most of the kernels), then
    // it does not make sense to create a pointer for it.
  }

  // Insert new global LDS pointer which points to LDS.
  GlobalVariable *createLDSPointer(GlobalVariable *GV) {
    // LDS pointer which points to LDS is already created? return it.
    auto PointerEntry = LDSToPointer.insert(std::make_pair(GV, nullptr));
    if (!PointerEntry.second)
      return PointerEntry.first->second;

    // We need to create new LDS pointer which points to LDS.
    //
    // Each CU owns at max 64K of LDS memory, so LDS address ranges from 0 to
    // 2^16 - 1. Hence 16 bit pointer is enough to hold the LDS address.
    auto *I16Ty = Type::getInt16Ty(Ctx);
    GlobalVariable *LDSPointer = new GlobalVariable(
        M, I16Ty, false, GlobalValue::InternalLinkage, UndefValue::get(I16Ty),
        GV->getName() + Twine(".ptr"), nullptr, GlobalVariable::NotThreadLocal,
        AMDGPUAS::LOCAL_ADDRESS);

    LDSPointer->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
    LDSPointer->setAlignment(AMDGPU::getAlign(DL, LDSPointer));

    // Mark that an associated LDS pointer is created for LDS.
    LDSToPointer[GV] = LDSPointer;

    return LDSPointer;
  }

  // Split entry basic block in such a way that only lane 0 of each wave does
  // the LDS pointer initialization, and return newly created basic block.
  BasicBlock *activateLaneZero(Function *K) {
    // If the entry basic block of kernel K is already splitted, then return
    // newly created basic block.
    auto BasicBlockEntry = KernelToInitBB.insert(std::make_pair(K, nullptr));
    if (!BasicBlockEntry.second)
      return BasicBlockEntry.first->second;

    // Split entry basic block of kernel K.
    auto *EI = &(*(K->getEntryBlock().getFirstInsertionPt()));
    IRBuilder<> Builder(EI);

    Value *Mbcnt =
        Builder.CreateIntrinsic(Intrinsic::amdgcn_mbcnt_lo, {},
                                {Builder.getInt32(-1), Builder.getInt32(0)});
    Value *Cond = Builder.CreateICmpEQ(Mbcnt, Builder.getInt32(0));
    Instruction *WB = cast<Instruction>(
        Builder.CreateIntrinsic(Intrinsic::amdgcn_wave_barrier, {}, {}));

    BasicBlock *NBB = SplitBlockAndInsertIfThen(Cond, WB, false)->getParent();

    // Mark that the entry basic block of kernel K is splitted.
    KernelToInitBB[K] = NBB;

    return NBB;
  }

  // Within given kernel, initialize given LDS pointer to point to given LDS.
  void initializeLDSPointer(Function *K, GlobalVariable *GV,
                            GlobalVariable *LDSPointer) {
    // If LDS pointer is already initialized within K, then nothing to do.
    auto PointerEntry = KernelToLDSPointers.insert(
        std::make_pair(K, SmallPtrSet<GlobalVariable *, 8>()));
    if (!PointerEntry.second)
      if (PointerEntry.first->second.contains(LDSPointer))
        return;

    // Insert instructions at EI which initialize LDS pointer to point-to LDS
    // within kernel K.
    //
    // That is, convert pointer type of GV to i16, and then store this converted
    // i16 value within LDSPointer which is of type i16*.
    auto *EI = &(*(activateLaneZero(K)->getFirstInsertionPt()));
    IRBuilder<> Builder(EI);
    Builder.CreateStore(Builder.CreatePtrToInt(GV, Type::getInt16Ty(Ctx)),
                        LDSPointer);

    // Mark that LDS pointer is initialized within kernel K.
    KernelToLDSPointers[K].insert(LDSPointer);
  }

  // We have created an LDS pointer for LDS, and initialized it to point-to LDS
  // within all relevent kernels. Now replace all the uses of LDS within
  // non-kernel functions by LDS pointer.
  void replaceLDSUseByPointer(GlobalVariable *GV, GlobalVariable *LDSPointer) {
    SmallVector<User *, 8> LDSUsers(GV->users());
    for (auto *U : LDSUsers) {
      // When `U` is a constant expression, it is possible that same constant
      // expression exists within multiple instructions, and within multiple
      // non-kernel functions. Collect all those non-kernel functions and all
      // those instructions within which `U` exist.
      auto FunctionToInsts =
          AMDGPU::getFunctionToInstsMap(U, false /*=CollectKernelInsts*/);

      for (auto FI = FunctionToInsts.begin(), FE = FunctionToInsts.end();
           FI != FE; ++FI) {
        Function *F = FI->first;
        auto &Insts = FI->second;
        for (auto *I : Insts) {
          // If `U` is a constant expression, then we need to break the
          // associated instruction into a set of separate instructions by
          // converting constant expressions into instructions.
          SmallPtrSet<Instruction *, 8> UserInsts;

          if (U == I) {
            // `U` is an instruction, conversion from constant expression to
            // set of instructions is *not* required.
            UserInsts.insert(I);
          } else {
            // `U` is a constant expression, convert it into corresponding set
            // of instructions.
            auto *CE = cast<ConstantExpr>(U);
            convertConstantExprsToInstructions(I, CE, &UserInsts);
          }

          // Go through all the user instrutions, if LDS exist within them as an
          // operand, then replace it by replace instruction.
          for (auto *II : UserInsts) {
            auto *ReplaceInst = getReplacementInst(F, GV, LDSPointer);
            II->replaceUsesOfWith(GV, ReplaceInst);
          }
        }
      }
    }
  }

  // Create a set of replacement instructions which together replace LDS within
  // non-kernel function F by accessing LDS indirectly using LDS pointer.
  Value *getReplacementInst(Function *F, GlobalVariable *GV,
                            GlobalVariable *LDSPointer) {
    // If the instruction which replaces LDS within F is already created, then
    // return it.
    auto LDSEntry = FunctionToLDSToReplaceInst.insert(
        std::make_pair(F, DenseMap<GlobalVariable *, Value *>()));
    if (!LDSEntry.second) {
      auto ReplaceInstEntry =
          LDSEntry.first->second.insert(std::make_pair(GV, nullptr));
      if (!ReplaceInstEntry.second)
        return ReplaceInstEntry.first->second;
    }

    // Get the instruction insertion point within the beginning of the entry
    // block of current non-kernel function.
    auto *EI = &(*(F->getEntryBlock().getFirstInsertionPt()));
    IRBuilder<> Builder(EI);

    // Insert required set of instructions which replace LDS within F.
    auto *V = Builder.CreateBitCast(
        Builder.CreateGEP(
            LDSMemBaseAddr,
            Builder.CreateLoad(LDSPointer->getValueType(), LDSPointer)),
        GV->getType());

    // Mark that the replacement instruction which replace LDS within F is
    // created.
    FunctionToLDSToReplaceInst[F][GV] = V;

    return V;
  }

public:
  ReplaceLDSUseImpl(Module &M)
      : M(M), Ctx(M.getContext()), DL(M.getDataLayout()) {
    LDSMemBaseAddr = Constant::getIntegerValue(
        PointerType::get(Type::getInt8Ty(M.getContext()),
                         AMDGPUAS::LOCAL_ADDRESS),
        APInt(32, 0));
  }

  // Entry-point function which interface ReplaceLDSUseImpl with outside of the
  // class.
  bool replaceLDSUse();

private:
  // For a given LDS from collected LDS globals set, replace its non-kernel
  // function scope uses by pointer.
  bool replaceLDSUse(GlobalVariable *GV);
};

// For given LDS from collected LDS globals set, replace its non-kernel function
// scope uses by pointer.
bool ReplaceLDSUseImpl::replaceLDSUse(GlobalVariable *GV) {
  // Holds all those non-kernel functions within which LDS is being accessed.
  const SmallPtrSet<Function *, 8> &LDSAccessors = LDSToNonKernels[GV];

  // The LDS pointer which points to LDS and replaces all the uses of LDS.
  GlobalVariable *LDSPointer = nullptr;

  // Traverse through each kernel K, check and if required, initialize the
  // LDS pointer to point to LDS within K.
  for (auto &KernelAndCallees : KernelToCallees) {
    Function *K = KernelAndCallees.first;
    const SmallPtrSet<Function *, 8> &Callees = KernelAndCallees.second;

    // Is any LDS accessing non-kernel function reachable from kernel K? Query
    // membership directly rather than copying and intersecting the whole
    // callee set for every (kernel, LDS) pair.
    bool ReachesLDSAccessor = llvm::any_of(
        LDSAccessors, [&](Function *F) { return Callees.contains(F); });

    // If not, there is no need to initialize LDS pointer within kernel K.
    if (!ReachesLDSAccessor)
      continue;

    // We need to initialize LDS pointer within kernel K, and to replace LDS
    // use within the reachable accessors by LDS pointer. First get the LDS
    // pointer, creating it if this is the first kernel which needs it.
    LDSPointer = createLDSPointer(GV);

    // Initialize LDS pointer to point to LDS within kernel K.
    initializeLDSPointer(K, GV, LDSPointer);
  }

  // No kernel reaches an accessor of LDS, hence no LDS pointer was created
  // and the module is unchanged.
  if (!LDSPointer)
    return false;

  // We have created an LDS pointer for LDS, and initialized it to point-to LDS
  // within all relevant kernels. Now replace all the uses of LDS within
  // non-kernel functions by LDS pointer.
  replaceLDSUseByPointer(GV, LDSPointer);

  return true;
}

// Entry-point function which interface ReplaceLDSUseImpl with outside of the
// class.
bool ReplaceLDSUseImpl::replaceLDSUse() {
  // Collect LDS which requires their uses to be replaced by pointer.
  std::vector<GlobalVariable *> LDSGlobals =
      collectLDSRequiringPointerReplace();

  // No LDS to pointer-replace. Nothing to do.
  if (LDSGlobals.empty())
    return false;

  // Collect reachable callee set for each kernel defined in the module.
  AMDGPU::collectReachableCallees(M, KernelToCallees);

  if (KernelToCallees.empty()) {
    // Either module does not have any kernel definitions, or none of the kernel
    // has a call to non-kernel functions, or we could not resolve any of the
    // call sites to proper non-kernel functions, because of the situations like
    // inline asm calls. Nothing to replace.
    return false;
  }

  // For every LDS from collected LDS globals set, replace its non-kernel
  // function scope use by pointer.
  bool Changed = false;
  for (auto *GV : LDSGlobals)
    Changed |= replaceLDSUse(GV);

  return Changed;
}

class AMDGPUReplaceLDSUseWithPointer : public ModulePass {
public:
  static char ID;

  AMDGPUReplaceLDSUseWithPointer() : ModulePass(ID) {
    initializeAMDGPUReplaceLDSUseWithPointerPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
  }
};

} // namespace

// Pass identifier; the legacy pass manager uses its address, not its value.
char AMDGPUReplaceLDSUseWithPointer::ID = 0;
// Exported alias of the pass ID, so that other code can refer to this pass.
char &llvm::AMDGPUReplaceLDSUseWithPointerID =
    AMDGPUReplaceLDSUseWithPointer::ID;

// Register the legacy pass under DEBUG_TYPE, declaring its TargetPassConfig
// dependency (matching getAnalysisUsage()).
INITIALIZE_PASS_BEGIN(
    AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
    "Replace within non-kernel function use of LDS with pointer",
    false /*only look at the cfg*/, false /*analysis pass*/)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(
    AMDGPUReplaceLDSUseWithPointer, DEBUG_TYPE,
    "Replace within non-kernel function use of LDS with pointer",
    false /*only look at the cfg*/, false /*analysis pass*/)

// Legacy pass-manager entry point: returns true if the module was modified.
bool AMDGPUReplaceLDSUseWithPointer::runOnModule(Module &M) {
  return ReplaceLDSUseImpl(M).replaceLDSUse();
}

// Factory for the legacy pass; the caller (pass manager) takes ownership.
ModulePass *llvm::createAMDGPUReplaceLDSUseWithPointerPass() {
  auto *Pass = new AMDGPUReplaceLDSUseWithPointer();
  return Pass;
}

// New pass-manager entry point. All analyses are preserved only when the
// module was left untouched; otherwise the pass has created new globals,
// rewritten non-kernel functions and split kernel entry blocks, so no cached
// analysis result can be trusted any more.
PreservedAnalyses
AMDGPUReplaceLDSUseWithPointerPass::run(Module &M, ModuleAnalysisManager &AM) {
  ReplaceLDSUseImpl LDSUseReplacer{M};
  if (!LDSUseReplacer.replaceLDSUse())
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}
