1 //===- CodegenStrategy.cpp - Linalg programmable codegen strategy ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as 10 // composable rewrite patterns through a programmable CodegenStrategy object. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" 15 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/SCF/Transforms.h" 18 #include "mlir/Dialect/Vector/VectorOps.h" 19 #include "mlir/Dialect/Vector/VectorTransforms.h" 20 #include "mlir/Pass/PassManager.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 #include "mlir/Transforms/LoopUtils.h" 23 #include "mlir/Transforms/Passes.h" 24 25 using namespace mlir; 26 using namespace mlir::linalg; 27 28 #define DEBUG_TYPE "linalg-codegen-strategy" 29 30 void mlir::linalg::CodegenStrategy::transform(FuncOp func) const { 31 MLIRContext *context = func.getContext(); 32 // Emplace patterns one at a time while also maintaining a simple chained 33 // state transition. 34 unsigned stepCount = 0; 35 SmallVector<FrozenRewritePatternSet, 4> stage1Patterns; 36 auto zeroState = Identifier::get(std::to_string(stepCount), context); 37 auto currentState = zeroState; 38 for (const std::unique_ptr<Transformation> &t : transformationSequence) { 39 auto nextState = Identifier::get(std::to_string(++stepCount), context); 40 auto marker = (currentState == zeroState) 41 ? linalg::LinalgTransformationFilter( 42 t->filter, ArrayRef<Identifier>{}, nextState) 43 : linalg::LinalgTransformationFilter( 44 t->filter, currentState, nextState); 45 stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker)); 46 currentState = nextState; 47 } 48 49 RewritePatternSet stage2Patterns = 50 linalg::getLinalgTilingCanonicalizationPatterns(context); 51 scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns); 52 53 auto stage3Transforms = [&](Operation *op) { 54 // Some of these may be too aggressive as a stage 3 that is applied on each 55 // stage 1 application and may have to be split out to post staged patterns 56 // application (in which case they could just be passes, TBD). 57 if (lateCodegenStrategyOptions.enableLICM) { 58 op->walk([&](LoopLikeOpInterface loopLike) { 59 LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n")); 60 if (failed(moveLoopInvariantCode(loopLike))) 61 llvm_unreachable("unexpected LICM failure"); 62 }); 63 } 64 promoteSingleIterationLoops(cast<FuncOp>(op)); 65 if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfers) 66 hoistRedundantVectorTransfers(cast<FuncOp>(op)); 67 if (lateCodegenStrategyOptions.enableHoistRedundantVectorTransfersOnTensor) 68 hoistRedundantVectorTransfersOnTensor(cast<FuncOp>(op)); 69 return success(); 70 }; 71 (void)linalg::applyStagedPatterns( 72 func, stage1Patterns, std::move(stage2Patterns), stage3Transforms); 73 74 //===--------------------------------------------------------------------===// 75 // Post staged patterns transforms 76 //===--------------------------------------------------------------------===// 77 78 // Programmatic splitting of slow/fast path vector transfers. 79 if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) { 80 RewritePatternSet patterns(context); 81 patterns.add<vector::VectorTransferFullPartialRewriter>( 82 context, vectorTransformsOptions); 83 (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); 84 } 85 86 // Programmatic controlled lowering of vector.contract only. 87 if (lateCodegenStrategyOptions.enableVectorContractLowering) { 88 RewritePatternSet vectorContractLoweringPatterns(context); 89 vectorContractLoweringPatterns 90 .add<ContractionOpToOuterProductOpLowering, 91 ContractionOpToMatmulOpLowering, ContractionOpLowering>( 92 vectorTransformsOptions, context); 93 vector::populateVectorTransferPermutationMapLoweringPatterns( 94 vectorContractLoweringPatterns); 95 (void)applyPatternsAndFoldGreedily( 96 func, std::move(vectorContractLoweringPatterns)); 97 } 98 99 // Programmatic controlled lowering of vector.transfer only. 100 if (lateCodegenStrategyOptions.enableVectorToSCFConversion) { 101 RewritePatternSet vectorToLoopsPatterns(context); 102 populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, 103 vectorToSCFOptions); 104 (void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns)); 105 } 106 107 // Ensure we drop the marker in the end. 108 func.walk([](LinalgOp op) { 109 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 110 }); 111 } 112