//===- Generalization.cpp - linalg named ops to generic ops --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the Linalg generalization pass. It converts named // Linalg ops to linalg.generic ops. // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "linalg-generalization" using namespace mlir; using namespace mlir::linalg; // Creates a linalg.generic op from the given `namedOp`. Returns a null op if // the given `namedOp` does not have a region builder. static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp, PatternRewriter &rewriter) { SmallVector inputOperands = namedOp.getInputOperands(); SmallVector outputOperands = namedOp.getOutputOperands(); SmallVector indexingMaps = namedOp.getIndexingMaps(); SmallVector iterators = llvm::to_vector<4>( namedOp.iterator_types().getAsValueRange()); SmallVector resultTypes = namedOp.getOutputTensorTypes(); SmallVector types(resultTypes.begin(), resultTypes.end()); // Inline the existing region if the named operation has a region attached. if (namedOp->getNumRegions() == 1) { GenericOp genericOp = rewriter.create(namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, iterators); rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), genericOp.region().begin()); return genericOp; } // Otherwise use the region builder to generate a new region. // TODO: Remove this path once all linag operations have a region attached. auto regionBuilder = namedOp.getRegionBuilder(); if (!regionBuilder) { LLVM_DEBUG(llvm::dbgs() << "no region builder for op: " << namedOp << "\n"); return nullptr; } return rewriter.create( namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, iterators, [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { ImplicitLocOpBuilder b(loc, bodyBuilder); regionBuilder(b, *bodyBuilder.getBlock()); }); } namespace { /// Base class for all linalg generalization patterns. A subclass must provide /// the following method: /// GenericOp createGenericOp(RootOp, PatternRewriter &) /// for creating the generic op. // TODO: remove this pattern after migrating all manually-written named ops // into auto-generated ones. template struct LinalgGeneralizationPattern : OpRewritePattern { LinalgGeneralizationPattern(MLIRContext *context, LinalgTransformationFilter marker, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), marker(std::move(marker)) {} LogicalResult matchAndRewrite(RootOp rootOp, PatternRewriter &rewriter) const override { auto linalgOp = dyn_cast(rootOp.getOperation()); if (!linalgOp) return failure(); if (failed(marker.checkAndNotify(rewriter, linalgOp))) return failure(); auto *pattern = static_cast(this); GenericOp genericOp = pattern->createGenericOp(rootOp, rewriter); if (!genericOp) return failure(); rewriter.replaceOp(rootOp, genericOp.getResults()); marker.replaceLinalgTransformationFilter(rewriter, genericOp.getOperation()); return success(); } private: LinalgTransformationFilter marker; }; struct GeneralizeConvOp : public LinalgGeneralizationPattern { using LinalgGeneralizationPattern::LinalgGeneralizationPattern; GenericOp createGenericOp(ConvOp convOp, OpBuilder &builder) const; }; /// Catch-all pattern for converting all named ops with a region builder into /// linalg.generic. struct LinalgNamedOpGeneralizationPattern : RewritePattern { LinalgNamedOpGeneralizationPattern(MLIRContext *context, LinalgTransformationFilter marker, PatternBenefit benefit = 1) : RewritePattern(MatchAnyOpTypeTag(), benefit, context), marker(std::move(marker)) {} LogicalResult matchAndRewrite(Operation *rootOp, PatternRewriter &rewriter) const override { auto linalgOp = dyn_cast(rootOp); if (!linalgOp) return failure(); if (failed(marker.checkAndNotify(rewriter, linalgOp))) return failure(); // No nothing to do for linalg.generic. if (isa(rootOp)) return failure(); GenericOp genericOp = createGenericOpFromNamedOp(linalgOp, rewriter); if (!genericOp) return failure(); rewriter.replaceOp(rootOp, genericOp.getResults()); marker.replaceLinalgTransformationFilter(rewriter, genericOp.getOperation()); return success(); } private: LinalgTransformationFilter marker; }; struct LinalgGeneralizationPass : public LinalgGeneralizationBase { void runOnFunction() override; }; } // namespace void LinalgGeneralizationPass::runOnFunction() { FuncOp func = getFunction(); RewritePatternSet patterns(&getContext()); populateLinalgConvGeneralizationPatterns(patterns); populateLinalgNamedOpsGeneralizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns)); } GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp, OpBuilder &builder) const { SmallVector indexingMaps = convOp.getIndexingMaps(); auto iterators = llvm::to_vector<4>(convOp.iterator_types().getAsValueRange()); SmallVector inputBuffers = convOp.getInputBufferOperands(); SmallVector outputBuffers = convOp.getOutputBufferOperands(); return builder.create( convOp.getLoc(), /*resultTensorTypes=*/ArrayRef(), inputBuffers, outputBuffers, indexingMaps, iterators, [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) { Value mul = bodyBuilder.create(bodyLoc, bodyArgs[0], bodyArgs[1]); Value add = bodyBuilder.create(bodyLoc, mul, bodyArgs[2]); bodyBuilder.create(bodyLoc, add); }); } void mlir::linalg::populateLinalgConvGeneralizationPatterns( RewritePatternSet &patterns, LinalgTransformationFilter marker) { patterns.add(patterns.getContext(), marker); } void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( RewritePatternSet &patterns, LinalgTransformationFilter marker) { patterns.add(patterns.getContext(), marker); } std::unique_ptr> mlir::createLinalgGeneralizationPass() { return std::make_unique(); }