1 //===- CodegenStrategy.cpp - Linalg programmable codegen strategy ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logic and helpers to expose Linalg transforms as
10 // composable rewrite patterns through a programmable CodegenStrategy object.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
15 
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Passes.h"
22 
23 using namespace mlir;
24 using namespace mlir::linalg;
25 
26 #define DEBUG_TYPE "linalg-codegen-strategy"
27 
28 void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
29   MLIRContext *context = func.getContext();
30   // Emplace patterns one at a time while also maintaining a simple chained
31   // state transition.
32   unsigned stepCount = 0;
33   SmallVector<OwningRewritePatternList, 4> stage1Patterns;
34   auto zeroState = Identifier::get(std::to_string(stepCount), context);
35   auto currentState = zeroState;
36   for (const std::unique_ptr<Transformation> &t : transformationSequence) {
37     auto nextState = Identifier::get(std::to_string(++stepCount), context);
38     auto marker = (currentState == zeroState)
39                       ? linalg::LinalgMarker({}, nextState)
40                       : linalg::LinalgMarker(currentState, nextState);
41     stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker));
42     currentState = nextState;
43   }
44 
45   OwningRewritePatternList stage2Patterns =
46       linalg::getLinalgTilingCanonicalizationPatterns(context);
47   stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context);
48 
49   auto stage3Transforms = [](Operation *op) {
50     // Some of these may be too aggressive as a stage 3 that is applied on each
51     // stage 1 application and may have to be split out to post staged patterns
52     // application (in which case they could just be passes, TBD).
53     PassManager pm(op->getContext());
54     pm.addPass(createLoopInvariantCodeMotionPass());
55     if (failed(pm.run(op->getParentOfType<ModuleOp>())))
56       llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
57     promoteSingleIterationLoops(cast<FuncOp>(op));
58     hoistViewAllocOps(cast<FuncOp>(op));
59     hoistRedundantVectorTransfers(cast<FuncOp>(op));
60     return success();
61   };
62   linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns,
63                               stage3Transforms);
64 
65   //===--------------------------------------------------------------------===//
66   // Post staged patterns transforms
67   //===--------------------------------------------------------------------===//
68 
69   ModuleOp module = func.getParentOfType<ModuleOp>();
70 
71   // Programmatic splitting of slow/fast path vector transfers.
72   OwningRewritePatternList patterns;
73   patterns.insert<vector::VectorTransferFullPartialRewriter>(
74       context, vectorTransformsOptions);
75   applyPatternsAndFoldGreedily(module, patterns);
76 
77   // Programmatic controlled lowering of vector.contract only.
78   OwningRewritePatternList vectorContractLoweringPatterns;
79   vectorContractLoweringPatterns
80       .insert<ContractionOpToOuterProductOpLowering,
81               ContractionOpToMatmulOpLowering, ContractionOpLowering>(
82           vectorTransformsOptions, context);
83   applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
84 
85   // Programmatic controlled lowering of vector.transfer only.
86   OwningRewritePatternList vectorToLoopsPatterns;
87   populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
88                                         vectorToSCFOptions);
89   applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
90 
91   // Ensure we drop the marker in the end.
92   module.walk([](LinalgOp op) {
93     op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
94   });
95 }
96