1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/SCF/Transforms/Transforms.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Pass/PassManager.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 #include "mlir/Transforms/Passes.h" 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 26 /// Use this to safely fill patterns for this test, since RewritePatternSet::add 27 /// forwards Rvalues only to the first pattern. 28 template <typename OpTy, LinalgTilingLoopType LoopType> 29 static void fillFusionPattern(MLIRContext *context, 30 const LinalgDependenceGraph &dependenceGraph, 31 RewritePatternSet &patterns, 32 const Twine &testCase, 33 ArrayRef<int64_t> tileSizes, 34 ArrayRef<int64_t> indicesToFuse) { 35 patterns.add<LinalgTileAndFusePattern<OpTy>>( 36 context, dependenceGraph, 37 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType), 38 LinalgFusionOptions().setIndicesToFuse(indicesToFuse), 39 LinalgTransformationFilter( 40 StringAttr::get(context, testCase + "_fusion"), 41 StringAttr::get(context, "after_" + testCase + "_fusion")), 42 LinalgTransformationFilter( 43 ArrayRef<StringAttr>(), 44 StringAttr::get(context, "after_" + testCase + "_fusion_producer")), 45 LinalgTransformationFilter( 46 ArrayRef<StringAttr>(), 47 StringAttr::get(context, "after_" + testCase + "_fusion_original"))); 48 } 49 50 template <LinalgTilingLoopType LoopType> 51 static void fillFusionPatterns(MLIRContext *context, 52 const LinalgDependenceGraph &dependenceGraph, 53 RewritePatternSet &patterns) { 54 fillFusionPattern<Conv2DOp, LoopType>(context, dependenceGraph, patterns, 55 /*testCase=*/"basic", 56 /*tileSizes=*/{32, 64, 16}, 57 /*indicesToFuse=*/{2}); 58 59 auto fillMatmulPattern = [&](const Twine &testCase, 60 ArrayRef<int64_t> indicesToFuse) { 61 fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns, 62 testCase, /*tileSizes=*/{32, 64, 16}, 63 indicesToFuse); 64 }; 65 fillMatmulPattern(/*testCase=*/"basic", 66 /*indicesToFuse=*/{2}); 67 fillMatmulPattern(/*testCase=*/"lhs", 68 /*indicesToFuse=*/{0}); 69 fillMatmulPattern(/*testCase=*/"out", 70 /*indicesToFuse=*/{2}); 71 fillMatmulPattern(/*testCase=*/"rhs", 72 /*indicesToFuse=*/{1}); 73 fillMatmulPattern(/*testCase=*/"two_operand", 74 /*indicesToFuse=*/{0, 2}); 75 76 fillFusionPattern<GenericOp, LoopType>(context, dependenceGraph, patterns, 77 /*testCase=*/"transpose", 78 /*tileSizes=*/{32, 64}, 79 /*indicesToFuse=*/{0, 1}); 80 } 81 82 namespace { 83 template <LinalgTilingLoopType LoopType> 84 struct TestLinalgFusionTransforms 85 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, 86 OperationPass<func::FuncOp>> { 87 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms) 88 89 void getDependentDialects(DialectRegistry ®istry) const override { 90 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 91 scf::SCFDialect>(); 92 } 93 TestLinalgFusionTransforms() = default; 94 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 95 96 void runOnOperation() override { 97 MLIRContext *context = &this->getContext(); 98 func::FuncOp funcOp = this->getOperation(); 99 RewritePatternSet fusionPatterns(context); 100 Aliases alias; 101 LinalgDependenceGraph dependenceGraph = 102 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 103 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 104 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 105 } 106 }; 107 108 struct TestLinalgFusionTransformsParallelLoops 109 : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> { 110 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 111 TestLinalgFusionTransformsParallelLoops) 112 113 StringRef getArgument() const final { 114 return "test-linalg-fusion-transform-patterns"; 115 } 116 StringRef getDescription() const final { 117 return "Test Linalg fusion transformation patterns by applying them " 118 "greedily."; 119 } 120 }; 121 122 struct TestLinalgFusionTransformsLoops 123 : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> { 124 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops) 125 126 StringRef getArgument() const final { 127 return "test-linalg-tensor-fusion-transform-patterns"; 128 } 129 StringRef getDescription() const final { 130 return "Test Linalg on tensor fusion transformation " 131 "patterns by applying them greedily."; 132 } 133 }; 134 135 struct TestLinalgFusionTransformsTiledLoops 136 : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> { 137 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 138 TestLinalgFusionTransformsTiledLoops) 139 140 StringRef getArgument() const final { 141 return "test-linalg-tiled-loop-fusion-transform-patterns"; 142 } 143 StringRef getDescription() const final { 144 return "Test Linalg on tensor fusion transformation " 145 "patterns by applying them greedily."; 146 } 147 }; 148 } // namespace 149 150 static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) { 151 OpBuilder b(f); 152 DenseSet<Operation *> eraseSet; 153 154 // Save original Linalg ops, we only want to make a pass over those. 155 SmallVector<LinalgOp, 8> linalgOps; 156 f.walk([&](LinalgOp op) { 157 // TODO: support multi-results. 158 if (op->getNumResults() <= 1) 159 linalgOps.push_back(op); 160 }); 161 162 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 163 bool changed = false; 164 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 165 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 166 if (opOperand->get().getType().isa<MemRefType>()) { 167 // TODO: LinalgDependenceGraph should be able to update itself. 168 // The current naive and expensive reconstruction of the graph should be 169 // removed. 170 linalg::Aliases aliases; 171 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 172 auto info = fuseProducerOfBuffer(b, *opOperand, graph); 173 if (failed(info)) 174 continue; 175 auto *originalOp = info->originalProducer.getOperation(); 176 eraseSet.insert(originalOp); 177 auto *originalOpInLinalgOpsVector = 178 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 179 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 180 changed = true; 181 } else if (opOperand->get().getType().isa<RankedTensorType>()) { 182 // Tile and Fuse tensor input. 183 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 184 continue; 185 auto info = fuseProducerOfTensor(b, *opOperand); 186 if (failed(info)) 187 continue; 188 auto *originalOp = info->originalProducer.getOperation(); 189 auto *originalOpInLinalgOpsVector = 190 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 191 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 192 // Don't mark for erasure in the tensor case, let DCE handle this. 193 changed = true; 194 } 195 } 196 } 197 // The `fuseProducerOfBuffer` function performs structural checks and in 198 // particular that no covering read or write exist between the consumer and 199 // the producer. As a consequence, the only fusions that may occur preserve 200 // subsequent dependences and are guaranteed by construction to produce the 201 // whole view. We may thus erase the producer once it is fused. 202 for (auto *e : eraseSet) 203 e->erase(); 204 205 return changed ? success() : failure(); 206 } 207 208 namespace { 209 struct TestLinalgGreedyFusion 210 : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> { 211 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion) 212 213 void getDependentDialects(DialectRegistry ®istry) const override { 214 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 215 scf::SCFDialect>(); 216 } 217 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 218 StringRef getDescription() const final { 219 return "Test Linalg fusion by applying a greedy test transformation."; 220 } 221 void runOnOperation() override { 222 MLIRContext *context = &getContext(); 223 RewritePatternSet patterns = 224 linalg::getLinalgTilingCanonicalizationPatterns(context); 225 patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 226 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 227 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 228 OpPassManager pm(func::FuncOp::getOperationName()); 229 pm.addPass(createLoopInvariantCodeMotionPass()); 230 pm.addPass(createCanonicalizerPass()); 231 pm.addPass(createCSEPass()); 232 do { 233 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); 234 if (failed(runPipeline(pm, getOperation()))) 235 this->signalPassFailure(); 236 } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); 237 } 238 }; 239 240 /// Pass to test tile and fuse of sequence of operations. Intended only for 241 /// testing. 242 struct TestLinalgTileAndFuseSequencePass 243 : public PassWrapper<TestLinalgTileAndFuseSequencePass, 244 OperationPass<func::FuncOp>> { 245 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 246 TestLinalgTileAndFuseSequencePass) 247 248 StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } 249 StringRef getDescription() const final { 250 return "Test Linalg tiling and fusion of a sequence of Linalg operations."; 251 } 252 TestLinalgTileAndFuseSequencePass() = default; 253 TestLinalgTileAndFuseSequencePass( 254 const TestLinalgTileAndFuseSequencePass &pass) 255 : PassWrapper(pass){}; 256 257 ListOption<int64_t> tileSizes{*this, "tile-sizes", 258 llvm::cl::desc("Tile sizes to use for ops")}; 259 260 void getDependentDialects(DialectRegistry ®istry) const override { 261 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 262 scf::SCFDialect>(); 263 } 264 265 void runOnOperation() override { 266 func::FuncOp funcOp = getOperation(); 267 auto &blocks = funcOp.getBody().getBlocks(); 268 if (!llvm::hasSingleElement(blocks)) { 269 return; 270 } 271 SmallVector<LinalgOp, 2> linalgOps = 272 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 273 Aliases aliases; 274 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 275 OpBuilder builder(funcOp.getContext()); 276 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 277 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 278 return linalgOp.hasTensorSemantics(); 279 })) 280 loopType = LinalgTilingLoopType::Loops; 281 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 282 builder, linalgOps, dependenceGraph, 283 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 284 if (!tileAndFuseOps) 285 return signalPassFailure(); 286 if (linalgOps.back().hasTensorSemantics()) { 287 linalgOps.back().getOperation()->replaceAllUsesWith( 288 tileAndFuseOps->fusedLoops.front()); 289 } 290 for (auto op : linalgOps) 291 if (op.hasBufferSemantics()) 292 op.erase(); 293 } 294 }; 295 296 } // namespace 297 298 namespace mlir { 299 namespace test { 300 void registerTestLinalgFusionTransforms() { 301 PassRegistration<TestLinalgFusionTransformsParallelLoops>(); 302 } 303 void registerTestLinalgTensorFusionTransforms() { 304 PassRegistration<TestLinalgFusionTransformsLoops>(); 305 } 306 void registerTestLinalgTiledLoopFusionTransforms() { 307 PassRegistration<TestLinalgFusionTransformsTiledLoops>(); 308 } 309 void registerTestLinalgGreedyFusion() { 310 PassRegistration<TestLinalgGreedyFusion>(); 311 } 312 void registerTestLinalgTileAndFuseSequencePass() { 313 PassRegistration<TestLinalgTileAndFuseSequencePass>(); 314 } 315 316 } // namespace test 317 } // namespace mlir 318