1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassManager.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include "mlir/Transforms/Passes.h" 19 20 using namespace mlir; 21 using namespace mlir::linalg; 22 23 template <LinalgTilingLoopType LoopType> 24 static void fillFusionPatterns(MLIRContext *context, 25 const LinalgDependenceGraph &dependenceGraph, 26 RewritePatternSet &patterns) { 27 patterns.add<LinalgTileAndFusePattern<MatmulOp>, 28 LinalgTileAndFusePattern<ConvOp>>( 29 context, dependenceGraph, 30 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 31 LinalgFusionOptions().setIndicesToFuse({2}), 32 LinalgTransformationFilter( 33 Identifier::get("basic_fusion", context), 34 Identifier::get("after_basic_fusion", context)), 35 LinalgTransformationFilter( 36 ArrayRef<Identifier>(), 37 Identifier::get("after_basic_fusion_producer", context)), 38 LinalgTransformationFilter( 39 ArrayRef<Identifier>(), 40 Identifier::get("after_basic_fusion_original", context))); 41 42 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 43 context, dependenceGraph, 44 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 45 LinalgFusionOptions().setIndicesToFuse({0}), 46 LinalgTransformationFilter(Identifier::get("lhs_fusion", context), 47 Identifier::get("after_lhs_fusion", context)), 48 LinalgTransformationFilter( 49 ArrayRef<Identifier>(), 50 Identifier::get("after_lhs_fusion_producer", context)), 51 LinalgTransformationFilter( 52 ArrayRef<Identifier>(), 53 Identifier::get("after_lhs_fusion_original", context))); 54 55 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 56 context, dependenceGraph, 57 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 58 LinalgFusionOptions().setIndicesToFuse({2}), 59 LinalgTransformationFilter(Identifier::get("out_fusion", context), 60 Identifier::get("after_out_fusion", context)), 61 LinalgTransformationFilter( 62 ArrayRef<Identifier>(), 63 Identifier::get("after_out_fusion_producer", context)), 64 LinalgTransformationFilter( 65 ArrayRef<Identifier>(), 66 Identifier::get("after_out_fusion_original", context))); 67 68 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 69 context, dependenceGraph, 70 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 71 LinalgFusionOptions().setIndicesToFuse({1}), 72 LinalgTransformationFilter(Identifier::get("rhs_fusion", context), 73 Identifier::get("after_rhs_fusion", context)), 74 LinalgTransformationFilter( 75 ArrayRef<Identifier>(), 76 Identifier::get("after_rhs_fusion_producer", context)), 77 LinalgTransformationFilter( 78 ArrayRef<Identifier>(), 79 Identifier::get("after_rhs_fusion_original", context))); 80 81 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 82 context, dependenceGraph, 83 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 84 LinalgFusionOptions().setIndicesToFuse({0, 2}), 85 LinalgTransformationFilter( 86 Identifier::get("two_operand_fusion", context), 87 Identifier::get("after_two_operand_fusion", context)), 88 LinalgTransformationFilter( 89 ArrayRef<Identifier>(), 90 Identifier::get("after_two_operand_fusion_producer", context)), 91 LinalgTransformationFilter( 92 ArrayRef<Identifier>(), 93 Identifier::get("after_two_operand_fusion_original", context))); 94 95 patterns.add<LinalgTileAndFusePattern<GenericOp>>( 96 context, dependenceGraph, 97 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), 98 LinalgFusionOptions().setIndicesToFuse({0, 1}), 99 LinalgTransformationFilter( 100 Identifier::get("transpose_fusion", context), 101 Identifier::get("after_transpose_fusion", context)), 102 LinalgTransformationFilter( 103 ArrayRef<Identifier>(), 104 Identifier::get("after_transpose_fusion_producer", context)), 105 LinalgTransformationFilter( 106 ArrayRef<Identifier>(), 107 Identifier::get("after_transpose_fusion_original", context))); 108 } 109 110 namespace { 111 template <LinalgTilingLoopType LoopType> 112 struct TestLinalgFusionTransforms 113 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> { 114 void getDependentDialects(DialectRegistry ®istry) const override { 115 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 116 scf::SCFDialect, StandardOpsDialect>(); 117 } 118 TestLinalgFusionTransforms() = default; 119 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 120 121 void runOnFunction() override { 122 MLIRContext *context = &this->getContext(); 123 FuncOp funcOp = this->getFunction(); 124 RewritePatternSet fusionPatterns(context); 125 Aliases alias; 126 LinalgDependenceGraph dependenceGraph = 127 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 128 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 129 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 130 } 131 }; 132 133 struct TestLinalgFusionTransformsParallelLoops 134 : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> { 135 StringRef getArgument() const final { 136 return "test-linalg-fusion-transform-patterns"; 137 } 138 StringRef getDescription() const final { 139 return "Test Linalg fusion transformation patterns by applying them " 140 "greedily."; 141 } 142 }; 143 144 struct TestLinalgFusionTransformsLoops 145 : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> { 146 StringRef getArgument() const final { 147 return "test-linalg-tensor-fusion-transform-patterns"; 148 } 149 StringRef getDescription() const final { 150 return "Test Linalg on tensor fusion transformation " 151 "patterns by applying them greedily."; 152 } 153 }; 154 155 struct TestLinalgFusionTransformsTiledLoops 156 : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> { 157 StringRef getArgument() const final { 158 return "test-linalg-tiled-loop-fusion-transform-patterns"; 159 } 160 StringRef getDescription() const final { 161 return "Test Linalg on tensor fusion transformation " 162 "patterns by applying them greedily."; 163 } 164 }; 165 } // namespace 166 167 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { 168 OpBuilder b(f); 169 DenseSet<Operation *> eraseSet; 170 171 // Save original Linalg ops, we only want to make a pass over those. 172 SmallVector<LinalgOp, 8> linalgOps; 173 f.walk([&](LinalgOp op) { 174 // TODO: support multi-results. 175 if (op->getNumResults() <= 1) 176 linalgOps.push_back(op); 177 }); 178 179 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 180 bool changed = false; 181 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 182 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 183 if (opOperand->get().getType().isa<MemRefType>()) { 184 // TODO: LinalgDependenceGraph should be able to update itself. 185 // The current naive and expensive reconstruction of the graph should be 186 // removed. 187 linalg::Aliases aliases; 188 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 189 if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) { 190 auto *originalOp = info->originalProducer.getOperation(); 191 eraseSet.insert(originalOp); 192 auto *originalOpInLinalgOpsVector = 193 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 194 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 195 changed = true; 196 } 197 } else if (opOperand->get().getType().isa<RankedTensorType>()) { 198 // Tile and Fuse tensor input. 199 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 200 continue; 201 if (auto info = fuseProducerOfTensor(b, *opOperand)) { 202 auto *originalOp = info->originalProducer.getOperation(); 203 auto *originalOpInLinalgOpsVector = 204 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 205 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 206 // Don't mark for erasure in the tensor case, let DCE handle this. 207 changed = true; 208 } 209 } 210 } 211 } 212 // The `fuseProducerOfBuffer` function performs structural checks and in 213 // particular that no covering read or write exist between the consumer and 214 // the producer. As a consequence, the only fusions that may occur preserve 215 // subsequent dependences and are guaranteed by construction to produce the 216 // whole view. We may thus erase the producer once it is fused. 217 for (auto *e : eraseSet) 218 e->erase(); 219 220 return changed ? success() : failure(); 221 } 222 223 namespace { 224 struct TestLinalgGreedyFusion 225 : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> { 226 void getDependentDialects(DialectRegistry ®istry) const override { 227 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 228 scf::SCFDialect>(); 229 } 230 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 231 StringRef getDescription() const final { 232 return "Test Linalg fusion by applying a greedy test transformation."; 233 } 234 void runOnFunction() override { 235 MLIRContext *context = &getContext(); 236 RewritePatternSet patterns = 237 linalg::getLinalgTilingCanonicalizationPatterns(context); 238 patterns.add<AffineMinSCFCanonicalizationPattern, 239 ExtractSliceOfPadTensorSwapPattern>(context); 240 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 241 do { 242 (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); 243 PassManager pm(context); 244 pm.addPass(createLoopInvariantCodeMotionPass()); 245 pm.addPass(createCanonicalizerPass()); 246 pm.addPass(createCSEPass()); 247 LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>()); 248 if (failed(res)) 249 this->signalPassFailure(); 250 } while (succeeded(fuseLinalgOpsGreedily(getFunction()))); 251 } 252 }; 253 254 /// Pass to test tile and fuse of sequence of operations. Intended only for 255 /// testing. 256 struct TestLinalgTileAndFuseSequencePass 257 : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> { 258 StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } 259 StringRef getDescription() const final { 260 return "Test Linalg tiling and fusion of a sequence of Linalg operations."; 261 } 262 TestLinalgTileAndFuseSequencePass() = default; 263 TestLinalgTileAndFuseSequencePass( 264 const TestLinalgTileAndFuseSequencePass &pass){}; 265 266 ListOption<int64_t> tileSizes{ 267 *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), 268 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 269 270 void getDependentDialects(DialectRegistry ®istry) const override { 271 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 272 scf::SCFDialect>(); 273 } 274 275 void runOnFunction() override { 276 FuncOp funcOp = getOperation(); 277 auto &blocks = funcOp.getBody().getBlocks(); 278 if (!llvm::hasSingleElement(blocks)) { 279 return; 280 } 281 SmallVector<LinalgOp, 2> linalgOps = 282 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 283 Aliases aliases; 284 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 285 OpBuilder builder(funcOp.getContext()); 286 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 287 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 288 return linalgOp.hasTensorSemantics(); 289 })) 290 loopType = LinalgTilingLoopType::Loops; 291 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 292 builder, linalgOps, dependenceGraph, 293 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 294 if (!tileAndFuseOps) 295 return signalPassFailure(); 296 if (linalgOps.back().hasTensorSemantics()) { 297 linalgOps.back().getOperation()->replaceAllUsesWith( 298 tileAndFuseOps->fusedLoops.front()); 299 } 300 for (auto op : linalgOps) 301 if (op.hasBufferSemantics()) 302 op.erase(); 303 } 304 }; 305 306 } // namespace 307 308 namespace mlir { 309 namespace test { 310 void registerTestLinalgFusionTransforms() { 311 PassRegistration<TestLinalgFusionTransformsParallelLoops>(); 312 } 313 void registerTestLinalgTensorFusionTransforms() { 314 PassRegistration<TestLinalgFusionTransformsLoops>(); 315 } 316 void registerTestLinalgTiledLoopFusionTransforms() { 317 PassRegistration<TestLinalgFusionTransformsTiledLoops>(); 318 } 319 void registerTestLinalgGreedyFusion() { 320 PassRegistration<TestLinalgGreedyFusion>(); 321 } 322 void registerTestLinalgTileAndFuseSequencePass() { 323 PassRegistration<TestLinalgTileAndFuseSequencePass>(); 324 } 325 326 } // namespace test 327 } // namespace mlir 328