1 //===- CodegenStrategy.cpp - Linalg programmable codegen strategy ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as 10 // composable rewrite patterns through a programmable CodegenStrategy object. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" 15 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/Dialect/Vector/VectorTransforms.h" 19 #include "mlir/Pass/PassManager.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 #include "mlir/Transforms/LoopUtils.h" 22 #include "mlir/Transforms/Passes.h" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 27 #define DEBUG_TYPE "linalg-codegen-strategy" 28 29 void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { 30 MLIRContext *context = func.getContext(); 31 // Emplace patterns one at a time while also maintaining a simple chained 32 // state transition. 33 unsigned stepCount = 0; 34 SmallVector<FrozenRewritePatternList, 4> stage1Patterns; 35 auto zeroState = Identifier::get(std::to_string(stepCount), context); 36 auto currentState = zeroState; 37 for (const std::unique_ptr<Transformation> &t : transformationSequence) { 38 auto nextState = Identifier::get(std::to_string(++stepCount), context); 39 auto marker = (currentState == zeroState) 40 ? linalg::LinalgTransformationFilter( 41 t->filter, ArrayRef<Identifier>{}, nextState) 42 : linalg::LinalgTransformationFilter( 43 t->filter, currentState, nextState); 44 stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker)); 45 currentState = nextState; 46 } 47 48 OwningRewritePatternList stage2Patterns = 49 linalg::getLinalgTilingCanonicalizationPatterns(context); 50 stage2Patterns.insert<AffineMinSCFCanonicalizationPattern>(context); 51 52 auto stage3Transforms = [&](Operation *op) { 53 // Some of these may be too aggressive as a stage 3 that is applied on each 54 // stage 1 application and may have to be split out to post staged patterns 55 // application (in which case they could just be passes, TBD). 56 if (enableLICM) { 57 op->walk([&](LoopLikeOpInterface loopLike) { 58 LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n")); 59 if (failed(moveLoopInvariantCode(loopLike))) 60 llvm_unreachable("unexpected LICM failure"); 61 }); 62 } 63 promoteSingleIterationLoops(cast<FuncOp>(op)); 64 hoistViewAllocOps(cast<FuncOp>(op)); 65 hoistRedundantVectorTransfers(cast<FuncOp>(op)); 66 return success(); 67 }; 68 (void)linalg::applyStagedPatterns( 69 func, stage1Patterns, std::move(stage2Patterns), stage3Transforms); 70 71 //===--------------------------------------------------------------------===// 72 // Post staged patterns transforms 73 //===--------------------------------------------------------------------===// 74 75 // Programmatic splitting of slow/fast path vector transfers. 76 OwningRewritePatternList patterns; 77 patterns.insert<vector::VectorTransferFullPartialRewriter>( 78 context, vectorTransformsOptions); 79 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 80 81 // Programmatic controlled lowering of vector.contract only. 82 OwningRewritePatternList vectorContractLoweringPatterns; 83 vectorContractLoweringPatterns 84 .insert<ContractionOpToOuterProductOpLowering, 85 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 86 vectorTransformsOptions, context); 87 (void)applyPatternsAndFoldGreedily(func, 88 std::move(vectorContractLoweringPatterns)); 89 90 // Programmatic controlled lowering of vector.transfer only. 91 OwningRewritePatternList vectorToLoopsPatterns; 92 populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, 93 vectorToSCFOptions); 94 (void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns)); 95 96 // Ensure we drop the marker in the end. 97 func.walk([](LinalgOp op) { 98 op.removeAttr(LinalgTransforms::kLinalgTransformMarker); 99 }); 100 } 101