1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/SCF/Transforms.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Pass/PassManager.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "mlir/Transforms/Passes.h" 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 template <LinalgTilingLoopType LoopType> 25 static void fillFusionPatterns(MLIRContext *context, 26 const LinalgDependenceGraph &dependenceGraph, 27 RewritePatternSet &patterns) { 28 patterns.add<LinalgTileAndFusePattern<MatmulOp>, 29 LinalgTileAndFusePattern<Conv2DOp>>( 30 context, dependenceGraph, 31 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 32 LinalgFusionOptions().setIndicesToFuse({2}), 33 LinalgTransformationFilter( 34 StringAttr::get(context, "basic_fusion"), 35 StringAttr::get(context, "after_basic_fusion")), 36 LinalgTransformationFilter( 37 ArrayRef<StringAttr>(), 38 StringAttr::get(context, "after_basic_fusion_producer")), 39 LinalgTransformationFilter( 40 ArrayRef<StringAttr>(), 41 StringAttr::get(context, "after_basic_fusion_original"))); 42 43 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 44 context, dependenceGraph, 45 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 46 LinalgFusionOptions().setIndicesToFuse({0}), 47 LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"), 48 StringAttr::get(context, "after_lhs_fusion")), 49 LinalgTransformationFilter( 50 ArrayRef<StringAttr>(), 51 StringAttr::get(context, "after_lhs_fusion_producer")), 52 LinalgTransformationFilter( 53 ArrayRef<StringAttr>(), 54 StringAttr::get(context, "after_lhs_fusion_original"))); 55 56 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 57 context, dependenceGraph, 58 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 59 LinalgFusionOptions().setIndicesToFuse({2}), 60 LinalgTransformationFilter(StringAttr::get(context, "out_fusion"), 61 StringAttr::get(context, "after_out_fusion")), 62 LinalgTransformationFilter( 63 ArrayRef<StringAttr>(), 64 StringAttr::get(context, "after_out_fusion_producer")), 65 LinalgTransformationFilter( 66 ArrayRef<StringAttr>(), 67 StringAttr::get(context, "after_out_fusion_original"))); 68 69 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 70 context, dependenceGraph, 71 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 72 LinalgFusionOptions().setIndicesToFuse({1}), 73 LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"), 74 StringAttr::get(context, "after_rhs_fusion")), 75 LinalgTransformationFilter( 76 ArrayRef<StringAttr>(), 77 StringAttr::get(context, "after_rhs_fusion_producer")), 78 LinalgTransformationFilter( 79 ArrayRef<StringAttr>(), 80 StringAttr::get(context, "after_rhs_fusion_original"))); 81 82 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 83 context, dependenceGraph, 84 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 85 LinalgFusionOptions().setIndicesToFuse({0, 2}), 86 LinalgTransformationFilter( 87 StringAttr::get(context, "two_operand_fusion"), 88 StringAttr::get(context, "after_two_operand_fusion")), 89 LinalgTransformationFilter( 90 ArrayRef<StringAttr>(), 91 StringAttr::get(context, "after_two_operand_fusion_producer")), 92 LinalgTransformationFilter( 93 ArrayRef<StringAttr>(), 94 StringAttr::get(context, "after_two_operand_fusion_original"))); 95 96 patterns.add<LinalgTileAndFusePattern<GenericOp>>( 97 context, dependenceGraph, 98 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), 99 LinalgFusionOptions().setIndicesToFuse({0, 1}), 100 LinalgTransformationFilter( 101 StringAttr::get(context, "transpose_fusion"), 102 StringAttr::get(context, "after_transpose_fusion")), 103 LinalgTransformationFilter( 104 ArrayRef<StringAttr>(), 105 StringAttr::get(context, "after_transpose_fusion_producer")), 106 LinalgTransformationFilter( 107 ArrayRef<StringAttr>(), 108 StringAttr::get(context, "after_transpose_fusion_original"))); 109 } 110 111 namespace { 112 template <LinalgTilingLoopType LoopType> 113 struct TestLinalgFusionTransforms 114 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> { 115 void getDependentDialects(DialectRegistry ®istry) const override { 116 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 117 scf::SCFDialect, StandardOpsDialect>(); 118 } 119 TestLinalgFusionTransforms() = default; 120 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 121 122 void runOnFunction() override { 123 MLIRContext *context = &this->getContext(); 124 FuncOp funcOp = this->getFunction(); 125 RewritePatternSet fusionPatterns(context); 126 Aliases alias; 127 LinalgDependenceGraph dependenceGraph = 128 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 129 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 130 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 131 } 132 }; 133 134 struct TestLinalgFusionTransformsParallelLoops 135 : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> { 136 StringRef getArgument() const final { 137 return "test-linalg-fusion-transform-patterns"; 138 } 139 StringRef getDescription() const final { 140 return "Test Linalg fusion transformation patterns by applying them " 141 "greedily."; 142 } 143 }; 144 145 struct TestLinalgFusionTransformsLoops 146 : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> { 147 StringRef getArgument() const final { 148 return "test-linalg-tensor-fusion-transform-patterns"; 149 } 150 StringRef getDescription() const final { 151 return "Test Linalg on tensor fusion transformation " 152 "patterns by applying them greedily."; 153 } 154 }; 155 156 struct TestLinalgFusionTransformsTiledLoops 157 : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> { 158 StringRef getArgument() const final { 159 return "test-linalg-tiled-loop-fusion-transform-patterns"; 160 } 161 StringRef getDescription() const final { 162 return "Test Linalg on tensor fusion transformation " 163 "patterns by applying them greedily."; 164 } 165 }; 166 } // namespace 167 168 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { 169 OpBuilder b(f); 170 DenseSet<Operation *> eraseSet; 171 172 // Save original Linalg ops, we only want to make a pass over those. 173 SmallVector<LinalgOp, 8> linalgOps; 174 f.walk([&](LinalgOp op) { 175 // TODO: support multi-results. 176 if (op->getNumResults() <= 1) 177 linalgOps.push_back(op); 178 }); 179 180 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 181 bool changed = false; 182 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 183 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 184 if (opOperand->get().getType().isa<MemRefType>()) { 185 // TODO: LinalgDependenceGraph should be able to update itself. 186 // The current naive and expensive reconstruction of the graph should be 187 // removed. 188 linalg::Aliases aliases; 189 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 190 auto info = fuseProducerOfBuffer(b, *opOperand, graph); 191 if (failed(info)) 192 continue; 193 auto *originalOp = info->originalProducer.getOperation(); 194 eraseSet.insert(originalOp); 195 auto *originalOpInLinalgOpsVector = 196 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 197 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 198 changed = true; 199 } else if (opOperand->get().getType().isa<RankedTensorType>()) { 200 // Tile and Fuse tensor input. 201 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 202 continue; 203 auto info = fuseProducerOfTensor(b, *opOperand); 204 if (failed(info)) 205 continue; 206 auto *originalOp = info->originalProducer.getOperation(); 207 auto *originalOpInLinalgOpsVector = 208 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 209 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 210 // Don't mark for erasure in the tensor case, let DCE handle this. 211 changed = true; 212 } 213 } 214 } 215 // The `fuseProducerOfBuffer` function performs structural checks and in 216 // particular that no covering read or write exist between the consumer and 217 // the producer. As a consequence, the only fusions that may occur preserve 218 // subsequent dependences and are guaranteed by construction to produce the 219 // whole view. We may thus erase the producer once it is fused. 220 for (auto *e : eraseSet) 221 e->erase(); 222 223 return changed ? success() : failure(); 224 } 225 226 namespace { 227 struct TestLinalgGreedyFusion 228 : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> { 229 void getDependentDialects(DialectRegistry ®istry) const override { 230 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 231 scf::SCFDialect>(); 232 } 233 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 234 StringRef getDescription() const final { 235 return "Test Linalg fusion by applying a greedy test transformation."; 236 } 237 void runOnFunction() override { 238 MLIRContext *context = &getContext(); 239 RewritePatternSet patterns = 240 linalg::getLinalgTilingCanonicalizationPatterns(context); 241 patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 242 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 243 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 244 do { 245 (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); 246 PassManager pm(context); 247 pm.addPass(createLoopInvariantCodeMotionPass()); 248 pm.addPass(createCanonicalizerPass()); 249 pm.addPass(createCSEPass()); 250 LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>()); 251 if (failed(res)) 252 this->signalPassFailure(); 253 } while (succeeded(fuseLinalgOpsGreedily(getFunction()))); 254 } 255 }; 256 257 /// Pass to test tile and fuse of sequence of operations. Intended only for 258 /// testing. 259 struct TestLinalgTileAndFuseSequencePass 260 : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> { 261 StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } 262 StringRef getDescription() const final { 263 return "Test Linalg tiling and fusion of a sequence of Linalg operations."; 264 } 265 TestLinalgTileAndFuseSequencePass() = default; 266 TestLinalgTileAndFuseSequencePass( 267 const TestLinalgTileAndFuseSequencePass &pass){}; 268 269 ListOption<int64_t> tileSizes{ 270 *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), 271 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 272 273 void getDependentDialects(DialectRegistry ®istry) const override { 274 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 275 scf::SCFDialect>(); 276 } 277 278 void runOnFunction() override { 279 FuncOp funcOp = getOperation(); 280 auto &blocks = funcOp.getBody().getBlocks(); 281 if (!llvm::hasSingleElement(blocks)) { 282 return; 283 } 284 SmallVector<LinalgOp, 2> linalgOps = 285 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 286 Aliases aliases; 287 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 288 OpBuilder builder(funcOp.getContext()); 289 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 290 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 291 return linalgOp.hasTensorSemantics(); 292 })) 293 loopType = LinalgTilingLoopType::Loops; 294 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 295 builder, linalgOps, dependenceGraph, 296 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 297 if (!tileAndFuseOps) 298 return signalPassFailure(); 299 if (linalgOps.back().hasTensorSemantics()) { 300 linalgOps.back().getOperation()->replaceAllUsesWith( 301 tileAndFuseOps->fusedLoops.front()); 302 } 303 for (auto op : linalgOps) 304 if (op.hasBufferSemantics()) 305 op.erase(); 306 } 307 }; 308 309 } // namespace 310 311 namespace mlir { 312 namespace test { 313 void registerTestLinalgFusionTransforms() { 314 PassRegistration<TestLinalgFusionTransformsParallelLoops>(); 315 } 316 void registerTestLinalgTensorFusionTransforms() { 317 PassRegistration<TestLinalgFusionTransformsLoops>(); 318 } 319 void registerTestLinalgTiledLoopFusionTransforms() { 320 PassRegistration<TestLinalgFusionTransformsTiledLoops>(); 321 } 322 void registerTestLinalgGreedyFusion() { 323 PassRegistration<TestLinalgGreedyFusion>(); 324 } 325 void registerTestLinalgTileAndFuseSequencePass() { 326 PassRegistration<TestLinalgTileAndFuseSequencePass>(); 327 } 328 329 } // namespace test 330 } // namespace mlir 331