//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements logic for testing Linalg transformations. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/HoistPadding.h" #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" using namespace mlir; using namespace mlir::linalg; namespace { struct TestLinalgTransforms : public PassWrapper { TestLinalgTransforms() = default; TestLinalgTransforms(const TestLinalgTransforms &pass) {} void getDependentDialects(DialectRegistry ®istry) const override { // clang-format off registry.insert(); // clang-format on } StringRef getArgument() const final { return "test-linalg-transform-patterns"; } StringRef getDescription() const final { return "Test Linalg transformation patterns by applying them greedily."; } void runOnFunction() override; Option testPatterns{*this, "test-patterns", llvm::cl::desc("Test a mixed set of patterns"), llvm::cl::init(false)}; Option testMatmulToVectorPatterns1dTiling{ *this, "test-matmul-to-vector-patterns-tile-1d", llvm::cl::desc( "Test a fused pass that applies patterns from matmul to vectors via " "1-d tiling"), llvm::cl::init(false)}; Option testMatmulToVectorPatterns2dTiling{ *this, "test-matmul-to-vector-patterns-tile-2d", llvm::cl::desc( "Test a fused pass that applies patterns from matmul to vectors via " "2-d tiling"), llvm::cl::init(false)}; Option testPromotionOptions{*this, "test-linalg-promotion-options", llvm::cl::desc("Test promotion options"), llvm::cl::init(false)}; Option testTileAndDistributionOptions{ *this, "test-tile-and-distribute-options", llvm::cl::desc("Test tile and distribute options"), llvm::cl::init(false)}; Option testVectorTransferForwardingPatterns{ *this, "test-vector-transfer-forwarding-patterns", llvm::cl::desc( "Test a fused pass that forwards linalg.copy to vector.transfer"), llvm::cl::init(false)}; Option testGenericToVectorPattern{ *this, "test-linalg-to-vector-patterns", llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction " "in vector.contract form"), llvm::cl::init(false)}; Option testTilePattern{*this, "test-tile-pattern", llvm::cl::desc("Test tile pattern"), llvm::cl::init(false)}; Option testTileScalarizeDynamicDims{ *this, "test-tile-scalarize-dynamic-dims", llvm::cl::desc("Test tiling of dynamic dims by 1"), llvm::cl::init(false)}; Option testTransformPadTensor{ *this, "test-transform-pad-tensor", llvm::cl::desc("Test transform pad tensor by copying with generic ops"), llvm::cl::init(false)}; Option testGeneralizePadTensor{ *this, "test-generalize-pad-tensor", llvm::cl::desc("Test transform pad tensor by copying with generic ops"), llvm::cl::init(false)}; Option testSwapSubTensorPadTensor{ *this, "test-swap-subtensor-padtensor", llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " "pad_tensor(subtensor)"), llvm::cl::init(false)}; ListOption peeledLoops{ *this, "peeled-loops", llvm::cl::desc("Loops to be peeled when test-tile-pattern"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption tileSizes{ *this, "tile-sizes", llvm::cl::desc("Linalg tile sizes for test-tile-pattern"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption testTiledLoopPeeling{ *this, "test-tiled-loop-peeling", llvm::cl::desc("Test peeling of linalg.tiled_loop ops"), llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated}; Option skipPartial{ *this, "skip-partial", llvm::cl::desc("Skip loops inside partial iterations during peeling"), llvm::cl::init(false)}; Option loopType{ *this, "loop-type", llvm::cl::desc("Specify the type of loops to generate: for, parallel or " "tiled_loop"), llvm::cl::init("for")}; }; } // end anonymous namespace static void applyPatterns(FuncOp funcOp) { MLIRContext *ctx = funcOp.getContext(); RewritePatternSet patterns(ctx); //===--------------------------------------------------------------------===// // Linalg tiling patterns. //===--------------------------------------------------------------------===// patterns.add>( ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), StringAttr::get(ctx, "L3"))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), LinalgTransformationFilter(StringAttr::get(ctx, "L3"), StringAttr::get(ctx, "L2"))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), LinalgTransformationFilter(StringAttr::get(ctx, "L2"), StringAttr::get(ctx, "L1"))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), LinalgTransformationFilter(StringAttr::get(ctx, "L1"), StringAttr::get(ctx, "REG"))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( LinalgTilingLoopType::ParallelLoops), LinalgTransformationFilter(ArrayRef{}, StringAttr::get(ctx, "L1"))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes(8000), LinalgTransformationFilter( ArrayRef{StringAttr::get(ctx, "MEM"), StringAttr::get(ctx, "L3"), StringAttr::get(ctx, "L2")}, StringAttr::get(ctx, "REG"))); //===--------------------------------------------------------------------===// // Linalg tiling and permutation patterns. //===--------------------------------------------------------------------===// patterns.add>( ctx, LinalgTilingOptions() .setTileSizes({2000, 3000, 4000}) .setInterchange({1, 2, 0}), LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), StringAttr::get(ctx, "L2__with_perm__"))); patterns.add>( ctx, LinalgTilingOptions() .setTileSizes({200, 300, 400}) .setInterchange({1, 0, 2}), LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), StringAttr::get(ctx, "L1__with_perm__"))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), StringAttr::get(ctx, "REG__with_perm__"))); patterns.add>( ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), StringAttr::get(ctx, "L1__with_perm__"))); patterns.add>( ctx, LinalgTilingOptions() .setTileSizes({16, 8, 4}) .setInterchange({1, 2, 0}) .setLoopType(LinalgTilingLoopType::ParallelLoops), LinalgTransformationFilter( StringAttr::get(ctx, "par__with_perm__"), StringAttr::get(ctx, "after_par__with_perm__"))); //===--------------------------------------------------------------------===// // Linalg to loops patterns. //===--------------------------------------------------------------------===// patterns.add>( ctx, /*loweringType=*/LinalgLoweringType::Loops, LinalgTransformationFilter(StringAttr::get(ctx, "REG"))); //===--------------------------------------------------------------------===// // Linalg distribution patterns. //===--------------------------------------------------------------------===// LinalgLoopDistributionOptions distributionOptions; //===--------------------------------------------------------------------===// // Linalg to vector contraction patterns. //===--------------------------------------------------------------------===// patterns.add( ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) .addOpFilter()); //===--------------------------------------------------------------------===// // Linalg generic interchange pattern. //===--------------------------------------------------------------------===// patterns.add( ctx, /*interchangeVector=*/ArrayRef{1, 2, 0}, LinalgTransformationFilter(ArrayRef{}, StringAttr::get(ctx, "PERMUTED"))); //===--------------------------------------------------------------------===// // Linalg subview operands promotion. //===--------------------------------------------------------------------===// patterns.add>( ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), LinalgTransformationFilter(StringAttr::get(ctx, "_promote_views_"), StringAttr::get(ctx, "_views_promoted_"))); patterns.add>( ctx, LinalgPromotionOptions() .setOperandsToPromote({0}) .setUseFullTileBuffersByDefault(true), LinalgTransformationFilter( StringAttr::get(ctx, "_promote_first_view_"), StringAttr::get(ctx, "_first_view_promoted_"))); patterns.add>( ctx, LinalgPromotionOptions() .setOperandsToPromote({1}) .setUseFullTileBuffers({false, true}) .setAlignment(32), LinalgTransformationFilter( StringAttr::get(ctx, "_promote_views_aligned_"), StringAttr::get(ctx, "_views_aligned_promoted_"))); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Drop the marker. funcOp.walk([](LinalgOp op) { op->removeAttr(LinalgTransforms::kLinalgTransformMarker); }); } static void fillL1TilingAndMatmulToVectorPatterns( FuncOp funcOp, StringRef startMarker, SmallVectorImpl &patternsVector) { MLIRContext *ctx = funcOp.getContext(); patternsVector.emplace_back( ctx, std::make_unique>( ctx, LinalgTilingOptions() .setTileSizes({8, 12, 16}) .setInterchange({1, 0, 2}), LinalgTransformationFilter(StringAttr::get(ctx, startMarker), StringAttr::get(ctx, "L1")))); patternsVector.emplace_back( ctx, std::make_unique>( ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), LinalgTransformationFilter(StringAttr::get(ctx, "L1"), StringAttr::get(ctx, "VEC")))); patternsVector.emplace_back( ctx, std::make_unique( MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); patternsVector.back().add( ctx, LinalgTransformationFilter().addFilter( [](Operation *op) { return success(isa(op)); })); } //===----------------------------------------------------------------------===// // Test promotion callbacks //===----------------------------------------------------------------------===// // Allocation call back static Optional allocCallBackFn(OpBuilder &b, memref::SubViewOp subView, ArrayRef boundingSubViewSize, DataLayout &layout) { SmallVector shape(boundingSubViewSize.size(), -1); return b .create( subView.getLoc(), MemRefType::get(shape, subView.getType().getElementType(), /*affineMapComposition =*/{}, 3), boundingSubViewSize) .getResult(); } // Deallocation callback static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) { b.create(buffer.getLoc(), buffer); return success(); } // Copy in call back static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst, bool isOutput) { auto floatType = src.getType().cast().getElementType(); if (!floatType.isa()) return failure(); if (!isOutput) { Value cst = b.create(src.getLoc(), FloatAttr::get(floatType, 42.0)); b.create(src.getLoc(), cst, dst); } b.create(src.getLoc(), src, dst); return success(); } static void fillPromotionCallBackPatterns(MLIRContext *ctx, RewritePatternSet &patterns) { patterns.add>( ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), LinalgTransformationFilter(StringAttr::get(ctx, "START"), StringAttr::get(ctx, "PROMOTE"))); patterns.add>( ctx, LinalgPromotionOptions() .setOperandsToPromote({0, 2}) .setUseFullTileBuffers({false, false}) .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn) .setCopyInOutFns( [](OpBuilder &b, Value src, Value dst) -> LogicalResult { return copyCallBackFn(b, src, dst, false); }, [](OpBuilder &b, Value src, Value dst) -> LogicalResult { return copyCallBackFn(b, src, dst, true); }), LinalgTransformationFilter(StringAttr::get(ctx, "PROMOTE"))); } template static SmallVector getGpuProcIds(OpBuilder &b, Location loc, ArrayRef parallelLoopRanges) { size_t count = std::min(3, parallelLoopRanges.size()); SmallVector procInfo(count); const char *xyz[] = {"x", "y", "z"}; Type indexType = b.getIndexType(); for (unsigned i = 0; i < count; ++i) { procInfo[count - 1 - i] = { b.create(loc, indexType, b.getStringAttr(xyz[i])), b.create(loc, indexType, b.getStringAttr(xyz[i]))}; } return procInfo; } static void fillTileAndDistributePatterns(MLIRContext *context, RewritePatternSet &patterns) { { LinalgLoopDistributionOptions cyclicNprocsEqNiters; cyclicNprocsEqNiters.distributionMethod.resize( 2, DistributionMethod::CyclicNumProcsEqNumIters); cyclicNprocsEqNiters.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsEqNiters), LinalgTransformationFilter( StringAttr::get(context, "distribute1"), StringAttr::get(context, "after_distribute1"))); } { LinalgLoopDistributionOptions cyclicNprocsGeNiters; cyclicNprocsGeNiters.distributionMethod.resize( 2, DistributionMethod::CyclicNumProcsGeNumIters); cyclicNprocsGeNiters.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsGeNiters), LinalgTransformationFilter( StringAttr::get(context, "distribute2"), StringAttr::get(context, "after_distribute2"))); } { LinalgLoopDistributionOptions cyclicNprocsDefault; cyclicNprocsDefault.distributionMethod.resize(2, DistributionMethod::Cyclic); cyclicNprocsDefault.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsDefault), LinalgTransformationFilter( StringAttr::get(context, "distribute3"), StringAttr::get(context, "after_distribute3"))); } { LinalgLoopDistributionOptions cyclicNprocsMixed1; cyclicNprocsMixed1.distributionMethod = { DistributionMethod::CyclicNumProcsEqNumIters, DistributionMethod::CyclicNumProcsGeNumIters}; cyclicNprocsMixed1.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsMixed1), LinalgTransformationFilter( StringAttr::get(context, "distribute4"), StringAttr::get(context, "after_distribute4"))); } { LinalgLoopDistributionOptions cyclicNprocsMixed2; cyclicNprocsMixed2.distributionMethod = { DistributionMethod::CyclicNumProcsGeNumIters, DistributionMethod::Cyclic}; cyclicNprocsMixed2.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsMixed2), LinalgTransformationFilter( StringAttr::get(context, "distribute5"), StringAttr::get(context, "after_distribute5"))); } { LinalgLoopDistributionOptions cyclicNprocsMixed3; cyclicNprocsMixed3.distributionMethod = { DistributionMethod::Cyclic, DistributionMethod::CyclicNumProcsEqNumIters}; cyclicNprocsMixed3.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) .setDistributionOptions(cyclicNprocsMixed3), LinalgTransformationFilter( StringAttr::get(context, "distribute6"), StringAttr::get(context, "after_distribute6"))); } { LinalgLoopDistributionOptions cyclicNprocsEqNiters; cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); cyclicNprocsEqNiters.procInfo = getGpuProcIds; patterns.add>( context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::Loops) .setDistributionOptions(cyclicNprocsEqNiters), LinalgTransformationFilter( StringAttr::get(context, "tensors_distribute1"), StringAttr::get(context, "tensors_after_distribute1"))); } } static void applyMatmulToVectorPatterns(FuncOp funcOp, bool testMatmulToVectorPatterns1dTiling, bool testMatmulToVectorPatterns2dTiling) { MLIRContext *ctx = funcOp.getContext(); SmallVector stage1Patterns; if (testMatmulToVectorPatterns1dTiling) { fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); } else if (testMatmulToVectorPatterns2dTiling) { stage1Patterns.emplace_back( ctx, std::make_unique>( ctx, LinalgTilingOptions() .setTileSizes({768, 264, 768}) .setInterchange({1, 2, 0}), LinalgTransformationFilter(StringAttr::get(ctx, "START"), StringAttr::get(ctx, "L2")))); fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns); } { // Canonicalization patterns RewritePatternSet canonicalizationPatterns(funcOp.getContext()); vector::populateVectorTransferPermutationMapLoweringPatterns( canonicalizationPatterns); vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); stage1Patterns.push_back(std::move(canonicalizationPatterns)); } SmallVector frozenStage1Patterns; llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); FrozenRewritePatternSet stage2Patterns = getLinalgTilingCanonicalizationPatterns(ctx); (void)applyStagedPatterns(funcOp, frozenStage1Patterns, std::move(stage2Patterns)); } static void applyVectorTransferForwardingPatterns(FuncOp funcOp) { RewritePatternSet forwardPattern(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); } static void applyLinalgToVectorPatterns(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add( funcOp.getContext(), LinalgTransformationFilter() .addOpFilter()); populatePadTensorOpVectorizationPatterns(patterns); populateConvolutionVectorizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyPadTensorToGenericPatterns(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyGeneralizePadTensorPatterns(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } static void applyTilePattern(FuncOp funcOp, std::string loopType, ArrayRef tileSizes, ArrayRef peeledLoops, bool scalarizeDynamicDims) { MLIRContext *context = funcOp.getContext(); RewritePatternSet tilingPattern(context); LinalgTilingLoopType type = llvm::StringSwitch(loopType) .Case("for", LinalgTilingLoopType::Loops) .Case("affine", LinalgTilingLoopType::AffineLoops) .Case("parallel", LinalgTilingLoopType::ParallelLoops) .Case("tiled_loop", LinalgTilingLoopType::TiledLoops); auto linalgTilingOptions = linalg::LinalgTilingOptions() .setPeeledLoops(peeledLoops) .setLoopType(type); if (scalarizeDynamicDims) { linalgTilingOptions.scalarizeDynamicDims(); assert(tileSizes.empty() && "tileSizes and scalarizeDynamicDims is mutually exclusive"); } else { linalgTilingOptions.setTileSizes(tileSizes); } tilingPattern.add, linalg::LinalgTilingPattern>( context, linalgTilingOptions, linalg::LinalgTransformationFilter(StringAttr::get(context, "tile"))); (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); } static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__"; static constexpr char kPartialIterationLabel[] = "__partial_iteration__"; namespace { /// Peel TiledLoopOps, i.e., split them into two loops: One loop where the /// `idx`-th loop contains only "full" iterations and a second loop for the /// remaining partial iteration (if any). struct TiledLoopPeelingPattern : public OpRewritePattern { TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial) : OpRewritePattern(ctx), idx(idx), skipPartial(skipPartial) { } LogicalResult matchAndRewrite(TiledLoopOp loopOp, PatternRewriter &rewriter) const override { SmallVector peeledLoops; if (loopOp->hasAttr(kPeeledLoopsLabel)) { auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast(); peeledLoops = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) { return attr.cast().getInt(); })); // Check if the loop was already peeled. if (llvm::find(peeledLoops, idx) != peeledLoops.end()) return failure(); } if (skipPartial && loopOp->hasAttr(kPartialIterationLabel)) // No peeling of loop nests with a partial iteration. return failure(); if (static_cast(loopOp.iterator_types().size()) <= idx) return failure(); // Peel loop and canonicalize. TiledLoopOp result; if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx, result))) return failure(); // Apply label, so that the same loop is not rewritten a second time. peeledLoops.push_back(idx); rewriter.updateRootInPlace(loopOp, [&]() { loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); }); result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops)); result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr()); return success(); } /// Index of loop to peel. int64_t idx; /// If set to true, do not peel TiledLoopOps with a partial iteration. bool skipPartial; }; } // namespace static void applyTiledLoopPeelingPattern(FuncOp funcOp, ArrayRef loops, bool skipPartial) { MLIRContext *ctx = funcOp.getContext(); RewritePatternSet patterns(ctx); for (unsigned idx : loops) patterns.add(ctx, idx, skipPartial); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Drop the markers. funcOp.walk([](TiledLoopOp op) { op->removeAttr(kPeeledLoopsLabel); op->removeAttr(kPartialIterationLabel); }); } /// Apply transformations specified as patterns. void TestLinalgTransforms::runOnFunction() { auto lambda = [&](void *) { getFunction().walk([](LinalgOp op) { op->removeAttr(LinalgTransforms::kLinalgTransformMarker); }); }; std::unique_ptr cleanupGuard{(void *)1, lambda}; if (testPromotionOptions) { RewritePatternSet patterns(&getContext()); fillPromotionCallBackPatterns(&getContext(), patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } if (testTileAndDistributionOptions) { RewritePatternSet patterns(&getContext()); fillTileAndDistributePatterns(&getContext(), patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); return; } if (testPatterns) return applyPatterns(getFunction()); if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling) return applyMatmulToVectorPatterns(getFunction(), testMatmulToVectorPatterns1dTiling, testMatmulToVectorPatterns2dTiling); if (testVectorTransferForwardingPatterns) return applyVectorTransferForwardingPatterns(getFunction()); if (testGenericToVectorPattern) return applyLinalgToVectorPatterns(getFunction()); if (testTransformPadTensor) return applyPadTensorToGenericPatterns(getFunction()); if (testGeneralizePadTensor) return applyGeneralizePadTensorPatterns(getFunction()); if (testSwapSubTensorPadTensor) return applyExtractSliceOfPadTensorSwapPattern(getFunction()); if (testTiledLoopPeeling.hasValue()) return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, skipPartial); if (testTilePattern) return applyTilePattern(getFunction(), loopType, tileSizes, peeledLoops, /*scalarizeDynamicDims=*/false); if (testTileScalarizeDynamicDims) return applyTilePattern(getFunction(), loopType, tileSizes, /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); } namespace mlir { namespace test { void registerTestLinalgTransforms() { PassRegistration(); } } // namespace test } // namespace mlir