1 //===- CodegenStrategy.cpp - Linalg programmable codegen strategy ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic and helpers to expose Linalg transforms as 10 // composable rewrite patterns through a programmable CodegenStrategy object. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/Dialect/Vector/IR/VectorOps.h" 19 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 20 #include "mlir/Pass/PassManager.h" 21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 22 #include "mlir/Transforms/Passes.h" 23 24 using namespace mlir; 25 using namespace mlir::linalg; 26 27 #define DEBUG_TYPE "linalg-codegen-strategy" 28 29 void mlir::linalg::CodegenStrategy::configurePassPipeline( 30 OpPassManager &pm, MLIRContext *context, bool addEnablePass) const { 31 for (unsigned stepCount = 0, e = transformationSequence.size(); stepCount < e; 32 ++stepCount) { 33 const std::unique_ptr<Transformation> &t = 34 transformationSequence[stepCount]; 35 std::string currentStr = std::to_string(stepCount); 36 auto currentState = StringAttr::get(context, currentStr); 37 std::string nextStr = std::to_string(stepCount + 1); 38 auto nextState = StringAttr::get(context, nextStr); 39 auto filter = (currentState.str() == std::to_string(0)) 40 ? linalg::LinalgTransformationFilter( 41 t->filter, ArrayRef<StringAttr>{}, nextState) 42 : linalg::LinalgTransformationFilter( 43 t->filter, currentState, nextState); 44 t->addToPassPipeline(pm, filter); 45 if (addEnablePass) 46 pm.addPass(createLinalgStrategyEnablePass(linalgEnablingOptions)); 47 } 48 pm.addPass(createLinalgStrategyRemoveMarkersPass()); 49 } 50