//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

/// Rewrites one gpu.all_reduce op nested in a gpu.func into explicit subgroup
/// shuffles, workgroup-memory loads/stores and barriers (see `rewrite()`).
///
/// The object keeps a reference to the pattern rewriter and copies of the op
/// handles it works on; in this file it is constructed for a single
/// `rewrite()` call and then discarded.
struct GpuAllReduceRewriter {
  /// Callback that emits IR combining two values (lhs, rhs) at the current
  /// insertion point and returns the accumulated value.
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  /// Captures the ops to rewrite and caches the location and the types used
  /// by the helpers below (the reduced value's type, `index` and `i32`).
  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(reduceOp.getContext(), /*width=*/32)) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory, and the all_reduce op is replaced by that result.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  /// (The sketch omits the empty `else` blocks that `createPredicatedBlock`
  /// actually emits between each `then` block and its `continue` block.)
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup). Masking with
    // kSubgroupSize - 1 equals `invocationIdx % kSubgroupSize` because
    // kSubgroupSize is a power of two.
    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
                                       create<ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create factory for op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<IndexCastOp>(indexType, subgroupId);
      create<memref::StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups, i.e. workgroupSize divided by
    // kSubgroupSize and rounded up.
    Value biasedBlockSize =
        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    // NOTE(review): the buffer only has kSubgroupSize slots, so this assumes
    // numSubgroups <= kSubgroupSize (workgroupSize <= kSubgroupSize^2);
    // confirm that the targets guarantee this.
    Value zero = create<ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<IndexCastOp>(indexType, invocationIdx);
      Value value = create<memref::LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<memref::StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<memref::LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  // `args` are taken by value, so `std::forward<Args>` moves the local copies.
  // NOTE(review): presumably deliberate — passing e.g. `kSubgroupSize` by
  // value copies the static constexpr member instead of binding a reference
  // to it; keep that in mind before switching to forwarding references.
  template <typename T, typename... Args>
  T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T for `dimension` (e.g. "x"), whose `index`
  // result is then converted to i32 with an index_cast.
  template <typename T>
  Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<IndexCastOp>(int32Type, dim);
  }

  /// Adds a 1-D memref of kSubgroupSize elements of valueType, in the GPU
  /// workgroup address space, to funcOp's workgroup attributions and returns
  /// it. `rewrite()` stores one partial result per subgroup into it and the
  /// final result into slot 0.
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region. A non-empty body region takes precedence; an empty (null) factory
  /// is returned when neither is usable.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments, and each gpu.yield returns the
  /// accumulated value of the same type.
  ///
  /// Each call of the factory splits the current block at the insertion
  /// point, clones `body` in between (mapping its arguments to lhs/rhs),
  /// branches into the clone and replaces every gpu.yield with a branch to the
  /// split-off block. That block gets the yielded value as a new argument,
  /// which is returned; the insertion point is left at its start.
  /// The returned lambda captures `body` and `this` by reference.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the original block and split block.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add branch before inserted body, into body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  /// Float value types use the float ops/predicates, all other types the
  /// integer ones. Returns an empty factory for unknown names.
  /// NOTE(review): integer "max"/"min" use unsigned predicates (ugt/ult) and
  /// float ones use unordered predicates (UGT/ULT); confirm this matches the
  /// intended all_reduce semantics for signed integers and NaN inputs.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    if (opName == "and") {
      return getFactory<AndOp>();
    }
    if (opName == "or") {
      return getFactory<OrOp>();
    }
    if (opName == "xor") {
      return getFactory<XOrOp>();
    }
    if (opName == "max") {
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
    }
    if (opName == "min") {
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
    }
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T, whose result
  /// type is taken from lhs.
  template <typename T>
  AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator for comparison such as min, max. T is the type
  /// of the compare op. The accumulator emits `cmp = T(predicate, lhs, rhs)`
  /// followed by `select cmp, lhs, rhs`.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` block.
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  /// Ops that followed the original insertion point end up in `^continue`.
  /// Both factories must return the same number of values; `^continue` gets
  /// one argument per then-operand, typed like it. On return the insertion
  /// point is at the start of `^continue`, so callers read the merged values
  /// through `rewriter.getInsertionBlock()->getArgument(i)`.
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with empty else block and no block operands. The
  /// factory must return void; the ops it creates run only when `condition`
  /// holds.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
  ///
  /// Both paths use "xor" shuffles with offsets 1, 2, 4, ..., kSubgroupSize/2.
  /// The partial-subgroup path only accumulates when the shuffle reports a
  /// valid source lane; the full-subgroup path accumulates unconditionally.
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            // createIf leaves the insertion point in its continue block; its
            // first argument is the (possibly) accumulated value.
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (kSubgroupSize, i.e. 32),
  /// using signed integer division.
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    return create<SignedDivIOp>(int32Type, value, subgroupSize);
  }

  // The function receiving the workgroup buffer, the op being lowered, and the
  // rewriter used to create and replace ops.
  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  // Location of reduceOp (used for every created op) and cached types: the
  // reduced value's type, `index`, and `i32`.
  Location loc;
  Type valueType;
  Type indexType;
  Type int32Type;

  // Subgroup width assumed by this lowering. Must be a power of two, since the
  // lane id is computed by masking with kSubgroupSize - 1.
  static constexpr int kSubgroupSize = 32;
};

/// Pattern rooted at gpu.func that lowers every nested gpu.all_reduce op, one
/// at a time, with GpuAllReduceRewriter.
struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  /// Rewrites all gpu.all_reduce ops nested in the gpu.func `op`.
  ///
  /// Returns failure when the function contains no all_reduce op. Pattern
  /// drivers interpret success as "the IR was changed"; reporting success
  /// unconditionally would make a greedy driver re-apply this pattern to every
  /// gpu.func on each iteration and never converge.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    bool rewroteAny = false;
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      rewroteAny = true;
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success(rewroteAny);
  }
};
} // namespace

/// Appends the pattern that lowers gpu.all_reduce ops inside gpu.func ops to
/// `patterns`, using the pattern set's own MLIR context.
void mlir::populateGpuAllReducePatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<GpuAllReduceConversion>(ctx);
}
