1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Dialect/SCF/Transforms.h" 16 #include "mlir/Pass/Pass.h" 17 #include "mlir/Pass/PassManager.h" 18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 19 #include "mlir/Transforms/Passes.h" 20 21 using namespace mlir; 22 using namespace mlir::linalg; 23 24 template <LinalgTilingLoopType LoopType> 25 static void fillFusionPatterns(MLIRContext *context, 26 const LinalgDependenceGraph &dependenceGraph, 27 RewritePatternSet &patterns) { 28 patterns.add<LinalgTileAndFusePattern<MatmulOp>, 29 LinalgTileAndFusePattern<Conv2DOp>>( 30 context, dependenceGraph, 31 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 32 LinalgFusionOptions().setIndicesToFuse({2}), 33 LinalgTransformationFilter( 34 Identifier::get("basic_fusion", context), 35 Identifier::get("after_basic_fusion", context)), 36 LinalgTransformationFilter( 37 ArrayRef<Identifier>(), 38 Identifier::get("after_basic_fusion_producer", context)), 39 LinalgTransformationFilter( 40 ArrayRef<Identifier>(), 41 Identifier::get("after_basic_fusion_original", context))); 42 43 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 44 context, dependenceGraph, 45 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 46 LinalgFusionOptions().setIndicesToFuse({0}), 47 LinalgTransformationFilter(Identifier::get("lhs_fusion", context), 48 Identifier::get("after_lhs_fusion", context)), 49 LinalgTransformationFilter( 50 ArrayRef<Identifier>(), 51 Identifier::get("after_lhs_fusion_producer", context)), 52 LinalgTransformationFilter( 53 ArrayRef<Identifier>(), 54 Identifier::get("after_lhs_fusion_original", context))); 55 56 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 57 context, dependenceGraph, 58 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 59 LinalgFusionOptions().setIndicesToFuse({2}), 60 LinalgTransformationFilter(Identifier::get("out_fusion", context), 61 Identifier::get("after_out_fusion", context)), 62 LinalgTransformationFilter( 63 ArrayRef<Identifier>(), 64 Identifier::get("after_out_fusion_producer", context)), 65 LinalgTransformationFilter( 66 ArrayRef<Identifier>(), 67 Identifier::get("after_out_fusion_original", context))); 68 69 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 70 context, dependenceGraph, 71 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 72 LinalgFusionOptions().setIndicesToFuse({1}), 73 LinalgTransformationFilter(Identifier::get("rhs_fusion", context), 74 Identifier::get("after_rhs_fusion", context)), 75 LinalgTransformationFilter( 76 ArrayRef<Identifier>(), 77 Identifier::get("after_rhs_fusion_producer", context)), 78 LinalgTransformationFilter( 79 ArrayRef<Identifier>(), 80 Identifier::get("after_rhs_fusion_original", context))); 81 82 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 83 context, dependenceGraph, 84 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 85 LinalgFusionOptions().setIndicesToFuse({0, 2}), 86 LinalgTransformationFilter( 87 Identifier::get("two_operand_fusion", context), 88 Identifier::get("after_two_operand_fusion", context)), 89 LinalgTransformationFilter( 90 ArrayRef<Identifier>(), 91 Identifier::get("after_two_operand_fusion_producer", context)), 92 LinalgTransformationFilter( 93 ArrayRef<Identifier>(), 94 Identifier::get("after_two_operand_fusion_original", context))); 95 96 patterns.add<LinalgTileAndFusePattern<GenericOp>>( 97 context, dependenceGraph, 98 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), 99 LinalgFusionOptions().setIndicesToFuse({0, 1}), 100 LinalgTransformationFilter( 101 Identifier::get("transpose_fusion", context), 102 Identifier::get("after_transpose_fusion", context)), 103 LinalgTransformationFilter( 104 ArrayRef<Identifier>(), 105 Identifier::get("after_transpose_fusion_producer", context)), 106 LinalgTransformationFilter( 107 ArrayRef<Identifier>(), 108 Identifier::get("after_transpose_fusion_original", context))); 109 } 110 111 namespace { 112 template <LinalgTilingLoopType LoopType> 113 struct TestLinalgFusionTransforms 114 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> { 115 void getDependentDialects(DialectRegistry ®istry) const override { 116 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 117 scf::SCFDialect, StandardOpsDialect>(); 118 } 119 TestLinalgFusionTransforms() = default; 120 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 121 122 void runOnFunction() override { 123 MLIRContext *context = &this->getContext(); 124 FuncOp funcOp = this->getFunction(); 125 RewritePatternSet fusionPatterns(context); 126 Aliases alias; 127 LinalgDependenceGraph dependenceGraph = 128 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 129 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 130 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 131 } 132 }; 133 134 struct TestLinalgFusionTransformsParallelLoops 135 : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> { 136 StringRef getArgument() const final { 137 return "test-linalg-fusion-transform-patterns"; 138 } 139 StringRef getDescription() const final { 140 return "Test Linalg fusion transformation patterns by applying them " 141 "greedily."; 142 } 143 }; 144 145 struct TestLinalgFusionTransformsLoops 146 : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> { 147 StringRef getArgument() const final { 148 return "test-linalg-tensor-fusion-transform-patterns"; 149 } 150 StringRef getDescription() const final { 151 return "Test Linalg on tensor fusion transformation " 152 "patterns by applying them greedily."; 153 } 154 }; 155 156 struct TestLinalgFusionTransformsTiledLoops 157 : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> { 158 StringRef getArgument() const final { 159 return "test-linalg-tiled-loop-fusion-transform-patterns"; 160 } 161 StringRef getDescription() const final { 162 return "Test Linalg on tensor fusion transformation " 163 "patterns by applying them greedily."; 164 } 165 }; 166 } // namespace 167 168 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { 169 OpBuilder b(f); 170 DenseSet<Operation *> eraseSet; 171 172 // Save original Linalg ops, we only want to make a pass over those. 173 SmallVector<LinalgOp, 8> linalgOps; 174 f.walk([&](LinalgOp op) { 175 // TODO: support multi-results. 176 if (op->getNumResults() <= 1) 177 linalgOps.push_back(op); 178 }); 179 180 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 181 bool changed = false; 182 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 183 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 184 if (opOperand->get().getType().isa<MemRefType>()) { 185 // TODO: LinalgDependenceGraph should be able to update itself. 186 // The current naive and expensive reconstruction of the graph should be 187 // removed. 188 linalg::Aliases aliases; 189 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 190 if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) { 191 auto *originalOp = info->originalProducer.getOperation(); 192 eraseSet.insert(originalOp); 193 auto *originalOpInLinalgOpsVector = 194 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 195 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 196 changed = true; 197 } 198 } else if (opOperand->get().getType().isa<RankedTensorType>()) { 199 // Tile and Fuse tensor input. 200 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 201 continue; 202 if (auto info = fuseProducerOfTensor(b, *opOperand)) { 203 auto *originalOp = info->originalProducer.getOperation(); 204 auto *originalOpInLinalgOpsVector = 205 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 206 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 207 // Don't mark for erasure in the tensor case, let DCE handle this. 208 changed = true; 209 } 210 } 211 } 212 } 213 // The `fuseProducerOfBuffer` function performs structural checks and in 214 // particular that no covering read or write exist between the consumer and 215 // the producer. As a consequence, the only fusions that may occur preserve 216 // subsequent dependences and are guaranteed by construction to produce the 217 // whole view. We may thus erase the producer once it is fused. 218 for (auto *e : eraseSet) 219 e->erase(); 220 221 return changed ? success() : failure(); 222 } 223 224 namespace { 225 struct TestLinalgGreedyFusion 226 : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> { 227 void getDependentDialects(DialectRegistry ®istry) const override { 228 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 229 scf::SCFDialect>(); 230 } 231 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 232 StringRef getDescription() const final { 233 return "Test Linalg fusion by applying a greedy test transformation."; 234 } 235 void runOnFunction() override { 236 MLIRContext *context = &getContext(); 237 RewritePatternSet patterns = 238 linalg::getLinalgTilingCanonicalizationPatterns(context); 239 patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 240 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 241 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 242 do { 243 (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); 244 PassManager pm(context); 245 pm.addPass(createLoopInvariantCodeMotionPass()); 246 pm.addPass(createCanonicalizerPass()); 247 pm.addPass(createCSEPass()); 248 LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>()); 249 if (failed(res)) 250 this->signalPassFailure(); 251 } while (succeeded(fuseLinalgOpsGreedily(getFunction()))); 252 } 253 }; 254 255 /// Pass to test tile and fuse of sequence of operations. Intended only for 256 /// testing. 257 struct TestLinalgTileAndFuseSequencePass 258 : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> { 259 StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } 260 StringRef getDescription() const final { 261 return "Test Linalg tiling and fusion of a sequence of Linalg operations."; 262 } 263 TestLinalgTileAndFuseSequencePass() = default; 264 TestLinalgTileAndFuseSequencePass( 265 const TestLinalgTileAndFuseSequencePass &pass){}; 266 267 ListOption<int64_t> tileSizes{ 268 *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), 269 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 270 271 void getDependentDialects(DialectRegistry ®istry) const override { 272 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 273 scf::SCFDialect>(); 274 } 275 276 void runOnFunction() override { 277 FuncOp funcOp = getOperation(); 278 auto &blocks = funcOp.getBody().getBlocks(); 279 if (!llvm::hasSingleElement(blocks)) { 280 return; 281 } 282 SmallVector<LinalgOp, 2> linalgOps = 283 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 284 Aliases aliases; 285 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 286 OpBuilder builder(funcOp.getContext()); 287 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 288 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 289 return linalgOp.hasTensorSemantics(); 290 })) 291 loopType = LinalgTilingLoopType::Loops; 292 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 293 builder, linalgOps, dependenceGraph, 294 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 295 if (!tileAndFuseOps) 296 return signalPassFailure(); 297 if (linalgOps.back().hasTensorSemantics()) { 298 linalgOps.back().getOperation()->replaceAllUsesWith( 299 tileAndFuseOps->fusedLoops.front()); 300 } 301 for (auto op : linalgOps) 302 if (op.hasBufferSemantics()) 303 op.erase(); 304 } 305 }; 306 307 } // namespace 308 309 namespace mlir { 310 namespace test { 311 void registerTestLinalgFusionTransforms() { 312 PassRegistration<TestLinalgFusionTransformsParallelLoops>(); 313 } 314 void registerTestLinalgTensorFusionTransforms() { 315 PassRegistration<TestLinalgFusionTransformsLoops>(); 316 } 317 void registerTestLinalgTiledLoopFusionTransforms() { 318 PassRegistration<TestLinalgFusionTransformsTiledLoops>(); 319 } 320 void registerTestLinalgGreedyFusion() { 321 PassRegistration<TestLinalgGreedyFusion>(); 322 } 323 void registerTestLinalgTileAndFuseSequencePass() { 324 PassRegistration<TestLinalgTileAndFuseSequencePass>(); 325 } 326 327 } // namespace test 328 } // namespace mlir 329