1 //===- CodegenStrategy.cpp - Linalg programmable codegen strategy ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as 10 // composable rewrite patterns through a programmable CodegenStrategy object. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" 15 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 18 #include "mlir/Dialect/SCF/Transforms.h" 19 #include "mlir/Dialect/Vector/VectorOps.h" 20 #include "mlir/Dialect/Vector/VectorTransforms.h" 21 #include "mlir/Pass/PassManager.h" 22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 23 #include "mlir/Transforms/LoopUtils.h" 24 #include "mlir/Transforms/Passes.h" 25 26 using namespace mlir; 27 using namespace mlir::linalg; 28 29 #define DEBUG_TYPE "linalg-codegen-strategy" 30 31 void mlir::linalg::CodegenStrategy::configurePassPipeline( 32 OpPassManager &pm, MLIRContext *context) const { 33 for (unsigned stepCount = 0, e = transformationSequence.size(); stepCount < e; 34 ++stepCount) { 35 const std::unique_ptr<Transformation> &t = 36 transformationSequence[stepCount]; 37 std::string currentStr = std::to_string(stepCount); 38 auto currentState = Identifier::get(currentStr, context); 39 std::string nextStr = std::to_string(stepCount + 1); 40 auto nextState = Identifier::get(nextStr, context); 41 auto filter = (currentState.str() == std::to_string(0)) 42 ? linalg::LinalgTransformationFilter( 43 t->filter, ArrayRef<Identifier>{}, nextState) 44 : linalg::LinalgTransformationFilter( 45 t->filter, currentState, nextState); 46 t->addToPassPipeline(pm, filter); 47 pm.addPass(createLinalgStrategyEnablePass()); 48 } 49 LinalgVectorLoweringOptions vectorLoweringOptions; 50 vectorLoweringOptions.maxTransferRank = 51 lateCodegenStrategyOptions.maxTransferRank; 52 vectorLoweringOptions.enableVectorTransferLowering = 53 lateCodegenStrategyOptions.enableVectorTransferLowering; 54 vectorLoweringOptions.enableVectorTransferPartialRewrite = 55 lateCodegenStrategyOptions.enableVectorTransferPartialRewrite; 56 vectorLoweringOptions.enableVectorContractLowering = 57 lateCodegenStrategyOptions.enableVectorContractLowering; 58 vectorLoweringOptions.enableVectorToSCFConversion = 59 lateCodegenStrategyOptions.enableVectorToSCFConversion; 60 vectorLoweringOptions.vectorTransformOptions = vectorTransformOptions; 61 vectorLoweringOptions.vectorTransferToSCFOptions = vectorToSCFOptions; 62 pm.addPass(createLinalgStrategyLowerVectorsPass(vectorLoweringOptions)); 63 } 64 65 LogicalResult mlir::linalg::CodegenStrategy::transform(FuncOp funcOp) const { 66 PassManager pm(funcOp.getContext(), funcOp.getOperationName()); 67 configurePassPipeline(pm, funcOp.getContext()); 68 LogicalResult res = pm.run(funcOp); 69 // Ensure we drop the marker in the end. 70 funcOp.walk([](LinalgOp op) { 71 op->removeAttr(LinalgTransforms::kLinalgTransformMarker); 72 }); 73 return res; 74 } 75