//===- Generalization.cpp - linalg named ops to generic ops  --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg generalization pass. It converts named
// Linalg ops to linalg.generic ops.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/EDSC/Builders.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-generalization"

using namespace mlir;

// Rewrites `namedOp` as a linalg.generic op created through `builder`. The
// generic op reuses the named op's inputs, output buffers, init tensors,
// indexing maps and iterator types, and its body is populated by the named
// op's region builder. Returns a null op (after logging a debug message) when
// `namedOp` provides no region builder.
static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
                                                    OpBuilder &builder) {
  auto regionBuilder = namedOp.getRegionBuilder();
  if (!regionBuilder) {
    LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n");
    return nullptr;
  }

  // The generic op produces the same tensor results as the named op.
  auto outputTensorTypes = namedOp.getOutputTensorTypes();
  SmallVector<Type, 4> resultTypes(outputTensorTypes.begin(),
                                   outputTensorTypes.end());
  SmallVector<AffineMap, 4> maps = namedOp.getIndexingMaps();
  auto iteratorNames = llvm::to_vector<4>(
      namedOp.iterator_types().getAsValueRange<StringAttr>());

  // An EDSC scoped context is kept alive while the region builder runs on the
  // body builder's current block.
  auto buildBody = [&](OpBuilder &nestedBuilder, Location nestedLoc,
                       ValueRange) {
    edsc::ScopedContext scope(nestedBuilder, nestedLoc);
    regionBuilder(*nestedBuilder.getBlock());
  };

  return builder.create<linalg::GenericOp>(
      namedOp.getLoc(), resultTypes, namedOp.getInputs(),
      namedOp.getOutputBuffers(), namedOp.getInitTensors(), maps,
      iteratorNames, buildBody);
}

namespace {

/// CRTP base for generalization patterns rooted at a specific op type.
/// `ConcretePattern` must implement:
///   linalg::GenericOp createGenericOp(RootOp, PatternRewriter &)
/// which builds the replacement generic op, or returns a null op to signal
/// that the rewrite does not apply.
// TODO: remove this pattern after migrating all manually-written named ops
// into auto-generated ones.
template <typename ConcretePattern, typename RootOp>
struct LinalgGeneralizationPattern : OpRewritePattern<RootOp> {
  LinalgGeneralizationPattern(MLIRContext *context, linalg::LinalgMarker marker,
                              PatternBenefit benefit = 1)
      : OpRewritePattern<RootOp>(context, benefit), marker(std::move(marker)) {}

  /// Replaces `rootOp` with the generic op produced by the concrete pattern.
  /// Fails when `rootOp` is not a LinalgOp, when the marker check rejects it,
  /// or when the concrete pattern returns a null op.
  LogicalResult matchAndRewrite(RootOp rootOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp.getOperation());
    // The marker is only consulted once the op is known to be a LinalgOp.
    if (!linalgOp || failed(marker.checkAndNotify(rewriter, linalgOp)))
      return failure();

    const auto &concrete = static_cast<const ConcretePattern &>(*this);
    linalg::GenericOp replacement = concrete.createGenericOp(rootOp, rewriter);
    if (!replacement)
      return failure();

    rewriter.replaceOp(rootOp, replacement.getResults());
    // Update the marker on the newly created generic op.
    marker.replaceLinalgMarker(rewriter, replacement.getOperation());
    return success();
  }

private:
  linalg::LinalgMarker marker;
};

/// Generalization pattern rooted at linalg::ConvOp. Construction is inherited
/// from the base pattern; the replacement generic op is built by
/// `createGenericOp`.
struct GeneralizeConvOp
    : LinalgGeneralizationPattern<GeneralizeConvOp, linalg::ConvOp> {
  using LinalgGeneralizationPattern::LinalgGeneralizationPattern;

  /// Builds the linalg.generic op that replaces `convOp`.
  linalg::GenericOp createGenericOp(linalg::ConvOp convOp,
                                    OpBuilder &builder) const;
};

/// Catch-all pattern, matching any operation type, that rewrites every Linalg
/// named op with a region builder into linalg.generic.
struct LinalgNamedOpGeneralizationPattern : RewritePattern {
  LinalgNamedOpGeneralizationPattern(MLIRContext *context,
                                     linalg::LinalgMarker marker,
                                     PatternBenefit benefit = 1)
      : RewritePattern(benefit, MatchAnyOpTypeTag()),
        marker(std::move(marker)) {}

  /// Replaces `rootOp` with an equivalent linalg.generic op. Fails when
  /// `rootOp` is not a LinalgOp, when the marker check rejects it, when it is
  /// already a (indexed) generic op, or when no generic op could be created.
  LogicalResult matchAndRewrite(Operation *rootOp,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<linalg::LinalgOp>(rootOp);
    // The marker is only consulted once the op is known to be a LinalgOp.
    if (!linalgOp || failed(marker.checkAndNotify(rewriter, linalgOp)))
      return failure();

    // Nothing to do for linalg.generic and linalg.indexed_generic.
    if (isa<linalg::GenericOp, linalg::IndexedGenericOp>(rootOp))
      return failure();

    linalg::GenericOp replacement =
        createGenericOpFromNamedOp(linalgOp, rewriter);
    if (!replacement)
      return failure();

    rewriter.replaceOp(rootOp, replacement.getResults());
    // Update the marker on the newly created generic op.
    marker.replaceLinalgMarker(rewriter, replacement.getOperation());
    return success();
  }

private:
  linalg::LinalgMarker marker;
};

/// Function pass that applies the Linalg generalization patterns; the work is
/// done in `runOnFunction`, defined out of line.
struct LinalgGeneralizationPass
    : public LinalgGeneralizationBase<LinalgGeneralizationPass> {
  void runOnFunction() override;
};

} // namespace

void LinalgGeneralizationPass::runOnFunction() {
  FuncOp func = getFunction();
  OwningRewritePatternList patterns;
  linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
  linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
  applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
}

/// Builds the linalg.generic op that replaces `convOp`.
///
/// The generic op reuses the input buffers, output buffers, indexing maps and
/// iterator types of `convOp`; it has no tensor results and no init tensors.
/// Its body multiplies the two input block arguments, adds the product to the
/// output block argument and yields the sum.
///
/// The arithmetic ops are chosen from the element type of the output block
/// argument (looking through vector types): signless-integer/index elements
/// use `MulIOp`/`AddIOp`, everything else uses `MulFOp`/`AddFOp`. Previously
/// the float ops were emitted unconditionally, which is ill-typed for integer
/// convolutions.
linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
                                                    OpBuilder &builder) const {
  SmallVector<AffineMap, 4> indexingMaps = convOp.getIndexingMaps();
  auto iterators =
      llvm::to_vector<4>(convOp.iterator_types().getAsValueRange<StringAttr>());
  return builder.create<linalg::GenericOp>(
      convOp.getLoc(), /*resultTensorTypes=*/ArrayRef<Type>(),
      convOp.getInputBuffers(), convOp.getOutputBuffers(),
      /*initTensors=*/ValueRange(), indexingMaps, iterators,
      [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) {
        // NOTE(review): assumes all three block arguments share one element
        // type, so inspecting the output argument is enough -- confirm
        // against the ConvOp verifier.
        Type scalarType = bodyArgs[2].getType();
        if (auto vectorType = scalarType.dyn_cast<VectorType>())
          scalarType = vectorType.getElementType();

        Value accumulated;
        if (scalarType.isSignlessIntOrIndex()) {
          Value product =
              bodyBuilder.create<MulIOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
          accumulated =
              bodyBuilder.create<AddIOp>(bodyLoc, product, bodyArgs[2]);
        } else {
          Value product =
              bodyBuilder.create<MulFOp>(bodyLoc, bodyArgs[0], bodyArgs[1]);
          accumulated =
              bodyBuilder.create<AddFOp>(bodyLoc, product, bodyArgs[2]);
        }
        bodyBuilder.create<linalg::YieldOp>(bodyLoc, accumulated);
      });
}

/// Inserts the `GeneralizeConvOp` pattern, constructed with `context` and
/// `marker`, into `patterns`.
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    linalg::LinalgMarker marker) {
  patterns.insert<GeneralizeConvOp>(context, marker);
}

/// Inserts the catch-all `LinalgNamedOpGeneralizationPattern`, constructed
/// with `context` and `marker`, into `patterns`.
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    linalg::LinalgMarker marker) {
  patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
}

/// Creates a new instance of the Linalg generalization pass.
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
  return std::make_unique<LinalgGeneralizationPass>();
}
