1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/SCF/Transforms.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Pass/PassManager.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "mlir/Transforms/Passes.h" 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 template <LinalgTilingLoopType LoopType> 25 static void fillFusionPatterns(MLIRContext *context, 26 const LinalgDependenceGraph &dependenceGraph, 27 RewritePatternSet &patterns) { 28 patterns.add<LinalgTileAndFusePattern<MatmulOp>, 29 LinalgTileAndFusePattern<Conv2DOp>>( 30 context, dependenceGraph, 31 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 32 LinalgFusionOptions().setIndicesToFuse({2}), 33 LinalgTransformationFilter( 34 StringAttr::get(context, "basic_fusion"), 35 StringAttr::get(context, "after_basic_fusion")), 36 LinalgTransformationFilter( 37 ArrayRef<StringAttr>(), 38 StringAttr::get(context, "after_basic_fusion_producer")), 39 LinalgTransformationFilter( 40 ArrayRef<StringAttr>(), 41 StringAttr::get(context, "after_basic_fusion_original"))); 42 43 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 44 context, dependenceGraph, 45 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 46 LinalgFusionOptions().setIndicesToFuse({0}), 47 LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"), 48 StringAttr::get(context, "after_lhs_fusion")), 49 LinalgTransformationFilter( 50 ArrayRef<StringAttr>(), 51 StringAttr::get(context, "after_lhs_fusion_producer")), 52 LinalgTransformationFilter( 53 ArrayRef<StringAttr>(), 54 StringAttr::get(context, "after_lhs_fusion_original"))); 55 56 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 57 context, dependenceGraph, 58 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 59 LinalgFusionOptions().setIndicesToFuse({2}), 60 LinalgTransformationFilter(StringAttr::get(context, "out_fusion"), 61 StringAttr::get(context, "after_out_fusion")), 62 LinalgTransformationFilter( 63 ArrayRef<StringAttr>(), 64 StringAttr::get(context, "after_out_fusion_producer")), 65 LinalgTransformationFilter( 66 ArrayRef<StringAttr>(), 67 StringAttr::get(context, "after_out_fusion_original"))); 68 69 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 70 context, dependenceGraph, 71 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 72 LinalgFusionOptions().setIndicesToFuse({1}), 73 LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"), 74 StringAttr::get(context, "after_rhs_fusion")), 75 LinalgTransformationFilter( 76 ArrayRef<StringAttr>(), 77 StringAttr::get(context, "after_rhs_fusion_producer")), 78 LinalgTransformationFilter( 79 ArrayRef<StringAttr>(), 80 StringAttr::get(context, "after_rhs_fusion_original"))); 81 82 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 83 context, dependenceGraph, 84 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 85 LinalgFusionOptions().setIndicesToFuse({0, 2}), 86 LinalgTransformationFilter( 87 StringAttr::get(context, "two_operand_fusion"), 88 StringAttr::get(context, "after_two_operand_fusion")), 89 LinalgTransformationFilter( 90 ArrayRef<StringAttr>(), 91 StringAttr::get(context, "after_two_operand_fusion_producer")), 92 LinalgTransformationFilter( 93 ArrayRef<StringAttr>(), 94 StringAttr::get(context, "after_two_operand_fusion_original"))); 95 96 patterns.add<LinalgTileAndFusePattern<GenericOp>>( 97 context, dependenceGraph, 98 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), 99 LinalgFusionOptions().setIndicesToFuse({0, 1}), 100 LinalgTransformationFilter( 101 StringAttr::get(context, "transpose_fusion"), 102 StringAttr::get(context, "after_transpose_fusion")), 103 LinalgTransformationFilter( 104 ArrayRef<StringAttr>(), 105 StringAttr::get(context, "after_transpose_fusion_producer")), 106 LinalgTransformationFilter( 107 ArrayRef<StringAttr>(), 108 StringAttr::get(context, "after_transpose_fusion_original"))); 109 } 110 111 namespace { 112 template <LinalgTilingLoopType LoopType> 113 struct TestLinalgFusionTransforms 114 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, 115 OperationPass<FuncOp>> { 116 void getDependentDialects(DialectRegistry ®istry) const override { 117 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 118 scf::SCFDialect>(); 119 } 120 TestLinalgFusionTransforms() = default; 121 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 122 123 void runOnOperation() override { 124 MLIRContext *context = &this->getContext(); 125 FuncOp funcOp = this->getOperation(); 126 RewritePatternSet fusionPatterns(context); 127 Aliases alias; 128 LinalgDependenceGraph dependenceGraph = 129 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 130 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 131 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 132 } 133 }; 134 135 struct TestLinalgFusionTransformsParallelLoops 136 : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> { 137 StringRef getArgument() const final { 138 return "test-linalg-fusion-transform-patterns"; 139 } 140 StringRef getDescription() const final { 141 return "Test Linalg fusion transformation patterns by applying them " 142 "greedily."; 143 } 144 }; 145 146 struct TestLinalgFusionTransformsLoops 147 : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> { 148 StringRef getArgument() const final { 149 return "test-linalg-tensor-fusion-transform-patterns"; 150 } 151 StringRef getDescription() const final { 152 return "Test Linalg on tensor fusion transformation " 153 "patterns by applying them greedily."; 154 } 155 }; 156 157 struct TestLinalgFusionTransformsTiledLoops 158 : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> { 159 StringRef getArgument() const final { 160 return "test-linalg-tiled-loop-fusion-transform-patterns"; 161 } 162 StringRef getDescription() const final { 163 return "Test Linalg on tensor fusion transformation " 164 "patterns by applying them greedily."; 165 } 166 }; 167 } // namespace 168 169 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { 170 OpBuilder b(f); 171 DenseSet<Operation *> eraseSet; 172 173 // Save original Linalg ops, we only want to make a pass over those. 174 SmallVector<LinalgOp, 8> linalgOps; 175 f.walk([&](LinalgOp op) { 176 // TODO: support multi-results. 177 if (op->getNumResults() <= 1) 178 linalgOps.push_back(op); 179 }); 180 181 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 182 bool changed = false; 183 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 184 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 185 if (opOperand->get().getType().isa<MemRefType>()) { 186 // TODO: LinalgDependenceGraph should be able to update itself. 187 // The current naive and expensive reconstruction of the graph should be 188 // removed. 189 linalg::Aliases aliases; 190 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 191 auto info = fuseProducerOfBuffer(b, *opOperand, graph); 192 if (failed(info)) 193 continue; 194 auto *originalOp = info->originalProducer.getOperation(); 195 eraseSet.insert(originalOp); 196 auto *originalOpInLinalgOpsVector = 197 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 198 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 199 changed = true; 200 } else if (opOperand->get().getType().isa<RankedTensorType>()) { 201 // Tile and Fuse tensor input. 202 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 203 continue; 204 auto info = fuseProducerOfTensor(b, *opOperand); 205 if (failed(info)) 206 continue; 207 auto *originalOp = info->originalProducer.getOperation(); 208 auto *originalOpInLinalgOpsVector = 209 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 210 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 211 // Don't mark for erasure in the tensor case, let DCE handle this. 212 changed = true; 213 } 214 } 215 } 216 // The `fuseProducerOfBuffer` function performs structural checks and in 217 // particular that no covering read or write exist between the consumer and 218 // the producer. As a consequence, the only fusions that may occur preserve 219 // subsequent dependences and are guaranteed by construction to produce the 220 // whole view. We may thus erase the producer once it is fused. 221 for (auto *e : eraseSet) 222 e->erase(); 223 224 return changed ? success() : failure(); 225 } 226 227 namespace { 228 struct TestLinalgGreedyFusion 229 : public PassWrapper<TestLinalgGreedyFusion, OperationPass<FuncOp>> { 230 void getDependentDialects(DialectRegistry ®istry) const override { 231 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 232 scf::SCFDialect>(); 233 } 234 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 235 StringRef getDescription() const final { 236 return "Test Linalg fusion by applying a greedy test transformation."; 237 } 238 void runOnOperation() override { 239 MLIRContext *context = &getContext(); 240 RewritePatternSet patterns = 241 linalg::getLinalgTilingCanonicalizationPatterns(context); 242 patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 243 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 244 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 245 do { 246 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); 247 PassManager pm(context); 248 pm.addPass(createLoopInvariantCodeMotionPass()); 249 pm.addPass(createCanonicalizerPass()); 250 pm.addPass(createCSEPass()); 251 LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>()); 252 if (failed(res)) 253 this->signalPassFailure(); 254 } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); 255 } 256 }; 257 258 /// Pass to test tile and fuse of sequence of operations. Intended only for 259 /// testing. 260 struct TestLinalgTileAndFuseSequencePass 261 : public PassWrapper<TestLinalgTileAndFuseSequencePass, 262 OperationPass<FuncOp>> { 263 StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } 264 StringRef getDescription() const final { 265 return "Test Linalg tiling and fusion of a sequence of Linalg operations."; 266 } 267 TestLinalgTileAndFuseSequencePass() = default; 268 TestLinalgTileAndFuseSequencePass( 269 const TestLinalgTileAndFuseSequencePass &pass) 270 : PassWrapper(pass){}; 271 272 ListOption<int64_t> tileSizes{ 273 *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), 274 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 275 276 void getDependentDialects(DialectRegistry ®istry) const override { 277 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 278 scf::SCFDialect>(); 279 } 280 281 void runOnOperation() override { 282 FuncOp funcOp = getOperation(); 283 auto &blocks = funcOp.getBody().getBlocks(); 284 if (!llvm::hasSingleElement(blocks)) { 285 return; 286 } 287 SmallVector<LinalgOp, 2> linalgOps = 288 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 289 Aliases aliases; 290 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 291 OpBuilder builder(funcOp.getContext()); 292 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 293 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 294 return linalgOp.hasTensorSemantics(); 295 })) 296 loopType = LinalgTilingLoopType::Loops; 297 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 298 builder, linalgOps, dependenceGraph, 299 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 300 if (!tileAndFuseOps) 301 return signalPassFailure(); 302 if (linalgOps.back().hasTensorSemantics()) { 303 linalgOps.back().getOperation()->replaceAllUsesWith( 304 tileAndFuseOps->fusedLoops.front()); 305 } 306 for (auto op : linalgOps) 307 if (op.hasBufferSemantics()) 308 op.erase(); 309 } 310 }; 311 312 } // namespace 313 314 namespace mlir { 315 namespace test { 316 void registerTestLinalgFusionTransforms() { 317 PassRegistration<TestLinalgFusionTransformsParallelLoops>(); 318 } 319 void registerTestLinalgTensorFusionTransforms() { 320 PassRegistration<TestLinalgFusionTransformsLoops>(); 321 } 322 void registerTestLinalgTiledLoopFusionTransforms() { 323 PassRegistration<TestLinalgFusionTransformsTiledLoops>(); 324 } 325 void registerTestLinalgGreedyFusion() { 326 PassRegistration<TestLinalgGreedyFusion>(); 327 } 328 void registerTestLinalgTileAndFuseSequencePass() { 329 PassRegistration<TestLinalgTileAndFuseSequencePass>(); 330 } 331 332 } // namespace test 333 } // namespace mlir 334