1 //===- CodegenStrategy.cpp - Linalg programmable codegen strategy ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as 10 // composable rewrite patterns through a programmable CodegenStrategy object. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" 15 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/Vector/VectorTransforms.h" 19 #include "mlir/Pass/PassManager.h" 20 #include "mlir/Transforms/LoopUtils.h" 21 #include "mlir/Transforms/Passes.h" 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 26 #define DEBUG_TYPE "linalg-codegen-strategy" 27 28 void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { 29 MLIRContext *context = func.getContext(); 30 // Emplace patterns one at a time while also maintaining a simple chained 31 // state transition. 32 unsigned stepCount = 0; 33 SmallVector<OwningRewritePatternList, 4> stage1Patterns; 34 auto zeroState = Identifier::get(std::to_string(stepCount), context); 35 auto currentState = zeroState; 36 for (const std::unique_ptr<Transformation> &t : transformationSequence) { 37 auto nextState = Identifier::get(std::to_string(++stepCount), context); 38 auto marker = (currentState == zeroState) 39 ? linalg::LinalgMarker({}, nextState) 40 : linalg::LinalgMarker(currentState, nextState); 41 stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker)); 42 currentState = nextState; 43 } 44 45 OwningRewritePatternList stage2Patterns = 46 linalg::getLinalgTilingCanonicalizationPatterns(context); 47 stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context); 48 49 auto stage3Transforms = [](Operation *op) { 50 // Some of these may be too aggressive as a stage 3 that is applied on each 51 // stage 1 application and may have to be split out to post staged patterns 52 // application (in which case they could just be passes, TBD). 53 PassManager pm(op->getContext()); 54 pm.addPass(createLoopInvariantCodeMotionPass()); 55 if (failed(pm.run(op->getParentOfType<ModuleOp>()))) 56 llvm_unreachable("Unexpected failure in cleanup pass pipeline."); 57 promoteSingleIterationLoops(cast<FuncOp>(op)); 58 hoistViewAllocOps(cast<FuncOp>(op)); 59 hoistRedundantVectorTransfers(cast<FuncOp>(op)); 60 return success(); 61 }; 62 linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns, 63 stage3Transforms); 64 65 //===--------------------------------------------------------------------===// 66 // Post staged patterns transforms 67 //===--------------------------------------------------------------------===// 68 69 ModuleOp module = func.getParentOfType<ModuleOp>(); 70 71 // Programmatic splitting of slow/fast path vector transfers. 72 OwningRewritePatternList patterns; 73 patterns.insert<vector::VectorTransferFullPartialRewriter>( 74 context, vectorTransformsOptions); 75 applyPatternsAndFoldGreedily(module, patterns); 76 77 // Programmatic controlled lowering of vector.contract only. 78 OwningRewritePatternList vectorContractLoweringPatterns; 79 vectorContractLoweringPatterns 80 .insert<ContractionOpToOuterProductOpLowering, 81 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 82 vectorTransformsOptions, context); 83 applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns); 84 85 // Programmatic controlled lowering of vector.transfer only. 86 OwningRewritePatternList vectorToLoopsPatterns; 87 populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, 88 vectorToSCFOptions); 89 applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns); 90 91 // Ensure we drop the marker in the end. 92 module.walk([](LinalgOp op) { 93 op.removeAttr(LinalgTransforms::kLinalgTransformMarker); 94 }); 95 } 96