1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/SCF/Transforms.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Pass/PassManager.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "mlir/Transforms/Passes.h" 21 22 using namespace mlir; 23 using namespace mlir::linalg; 24 25 template <LinalgTilingLoopType LoopType> 26 static void fillFusionPatterns(MLIRContext *context, 27 const LinalgDependenceGraph &dependenceGraph, 28 RewritePatternSet &patterns) { 29 patterns.add<LinalgTileAndFusePattern<MatmulOp>, 30 LinalgTileAndFusePattern<Conv2DOp>>( 31 context, dependenceGraph, 32 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 33 LinalgFusionOptions().setIndicesToFuse({2}), 34 LinalgTransformationFilter( 35 StringAttr::get(context, "basic_fusion"), 36 StringAttr::get(context, "after_basic_fusion")), 37 LinalgTransformationFilter( 38 ArrayRef<StringAttr>(), 39 StringAttr::get(context, "after_basic_fusion_producer")), 40 LinalgTransformationFilter( 41 ArrayRef<StringAttr>(), 42 StringAttr::get(context, "after_basic_fusion_original"))); 43 44 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 45 context, dependenceGraph, 46 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 47 LinalgFusionOptions().setIndicesToFuse({0}), 48 LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"), 49 StringAttr::get(context, "after_lhs_fusion")), 50 LinalgTransformationFilter( 51 ArrayRef<StringAttr>(), 52 StringAttr::get(context, "after_lhs_fusion_producer")), 53 LinalgTransformationFilter( 54 ArrayRef<StringAttr>(), 55 StringAttr::get(context, "after_lhs_fusion_original"))); 56 57 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 58 context, dependenceGraph, 59 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 60 LinalgFusionOptions().setIndicesToFuse({2}), 61 LinalgTransformationFilter(StringAttr::get(context, "out_fusion"), 62 StringAttr::get(context, "after_out_fusion")), 63 LinalgTransformationFilter( 64 ArrayRef<StringAttr>(), 65 StringAttr::get(context, "after_out_fusion_producer")), 66 LinalgTransformationFilter( 67 ArrayRef<StringAttr>(), 68 StringAttr::get(context, "after_out_fusion_original"))); 69 70 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 71 context, dependenceGraph, 72 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 73 LinalgFusionOptions().setIndicesToFuse({1}), 74 LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"), 75 StringAttr::get(context, "after_rhs_fusion")), 76 LinalgTransformationFilter( 77 ArrayRef<StringAttr>(), 78 StringAttr::get(context, "after_rhs_fusion_producer")), 79 LinalgTransformationFilter( 80 ArrayRef<StringAttr>(), 81 StringAttr::get(context, "after_rhs_fusion_original"))); 82 83 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 84 context, dependenceGraph, 85 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 86 LinalgFusionOptions().setIndicesToFuse({0, 2}), 87 LinalgTransformationFilter( 88 StringAttr::get(context, "two_operand_fusion"), 89 StringAttr::get(context, "after_two_operand_fusion")), 90 LinalgTransformationFilter( 91 ArrayRef<StringAttr>(), 92 StringAttr::get(context, "after_two_operand_fusion_producer")), 93 LinalgTransformationFilter( 94 ArrayRef<StringAttr>(), 95 StringAttr::get(context, "after_two_operand_fusion_original"))); 96 97 patterns.add<LinalgTileAndFusePattern<GenericOp>>( 98 context, dependenceGraph, 99 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), 100 LinalgFusionOptions().setIndicesToFuse({0, 1}), 101 LinalgTransformationFilter( 102 StringAttr::get(context, "transpose_fusion"), 103 StringAttr::get(context, "after_transpose_fusion")), 104 LinalgTransformationFilter( 105 ArrayRef<StringAttr>(), 106 StringAttr::get(context, "after_transpose_fusion_producer")), 107 LinalgTransformationFilter( 108 ArrayRef<StringAttr>(), 109 StringAttr::get(context, "after_transpose_fusion_original"))); 110 } 111 112 namespace { 113 template <LinalgTilingLoopType LoopType> 114 struct TestLinalgFusionTransforms 115 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, 116 OperationPass<FuncOp>> { 117 void getDependentDialects(DialectRegistry ®istry) const override { 118 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 119 scf::SCFDialect>(); 120 } 121 TestLinalgFusionTransforms() = default; 122 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 123 124 void runOnOperation() override { 125 MLIRContext *context = &this->getContext(); 126 FuncOp funcOp = this->getOperation(); 127 RewritePatternSet fusionPatterns(context); 128 Aliases alias; 129 LinalgDependenceGraph dependenceGraph = 130 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 131 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 132 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 133 } 134 }; 135 136 struct TestLinalgFusionTransformsParallelLoops 137 : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> { 138 StringRef getArgument() const final { 139 return "test-linalg-fusion-transform-patterns"; 140 } 141 StringRef getDescription() const final { 142 return "Test Linalg fusion transformation patterns by applying them " 143 "greedily."; 144 } 145 }; 146 147 struct TestLinalgFusionTransformsLoops 148 : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> { 149 StringRef getArgument() const final { 150 return "test-linalg-tensor-fusion-transform-patterns"; 151 } 152 StringRef getDescription() const final { 153 return "Test Linalg on tensor fusion transformation " 154 "patterns by applying them greedily."; 155 } 156 }; 157 158 struct TestLinalgFusionTransformsTiledLoops 159 : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> { 160 StringRef getArgument() const final { 161 return "test-linalg-tiled-loop-fusion-transform-patterns"; 162 } 163 StringRef getDescription() const final { 164 return "Test Linalg on tensor fusion transformation " 165 "patterns by applying them greedily."; 166 } 167 }; 168 } // namespace 169 170 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { 171 OpBuilder b(f); 172 DenseSet<Operation *> eraseSet; 173 174 // Save original Linalg ops, we only want to make a pass over those. 175 SmallVector<LinalgOp, 8> linalgOps; 176 f.walk([&](LinalgOp op) { 177 // TODO: support multi-results. 178 if (op->getNumResults() <= 1) 179 linalgOps.push_back(op); 180 }); 181 182 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 183 bool changed = false; 184 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 185 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 186 if (opOperand->get().getType().isa<MemRefType>()) { 187 // TODO: LinalgDependenceGraph should be able to update itself. 188 // The current naive and expensive reconstruction of the graph should be 189 // removed. 190 linalg::Aliases aliases; 191 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 192 auto info = fuseProducerOfBuffer(b, *opOperand, graph); 193 if (failed(info)) 194 continue; 195 auto *originalOp = info->originalProducer.getOperation(); 196 eraseSet.insert(originalOp); 197 auto *originalOpInLinalgOpsVector = 198 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 199 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 200 changed = true; 201 } else if (opOperand->get().getType().isa<RankedTensorType>()) { 202 // Tile and Fuse tensor input. 203 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 204 continue; 205 auto info = fuseProducerOfTensor(b, *opOperand); 206 if (failed(info)) 207 continue; 208 auto *originalOp = info->originalProducer.getOperation(); 209 auto *originalOpInLinalgOpsVector = 210 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 211 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 212 // Don't mark for erasure in the tensor case, let DCE handle this. 213 changed = true; 214 } 215 } 216 } 217 // The `fuseProducerOfBuffer` function performs structural checks and in 218 // particular that no covering read or write exist between the consumer and 219 // the producer. As a consequence, the only fusions that may occur preserve 220 // subsequent dependences and are guaranteed by construction to produce the 221 // whole view. We may thus erase the producer once it is fused. 222 for (auto *e : eraseSet) 223 e->erase(); 224 225 return changed ? success() : failure(); 226 } 227 228 namespace { 229 struct TestLinalgGreedyFusion 230 : public PassWrapper<TestLinalgGreedyFusion, OperationPass<FuncOp>> { 231 void getDependentDialects(DialectRegistry ®istry) const override { 232 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 233 scf::SCFDialect>(); 234 } 235 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 236 StringRef getDescription() const final { 237 return "Test Linalg fusion by applying a greedy test transformation."; 238 } 239 void runOnOperation() override { 240 MLIRContext *context = &getContext(); 241 RewritePatternSet patterns = 242 linalg::getLinalgTilingCanonicalizationPatterns(context); 243 patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 244 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 245 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 246 do { 247 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); 248 PassManager pm(context); 249 pm.addPass(createLoopInvariantCodeMotionPass()); 250 pm.addPass(createCanonicalizerPass()); 251 pm.addPass(createCSEPass()); 252 LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>()); 253 if (failed(res)) 254 this->signalPassFailure(); 255 } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); 256 } 257 }; 258 259 /// Pass to test tile and fuse of sequence of operations. Intended only for 260 /// testing. 261 struct TestLinalgTileAndFuseSequencePass 262 : public PassWrapper<TestLinalgTileAndFuseSequencePass, 263 OperationPass<FuncOp>> { 264 StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } 265 StringRef getDescription() const final { 266 return "Test Linalg tiling and fusion of a sequence of Linalg operations."; 267 } 268 TestLinalgTileAndFuseSequencePass() = default; 269 TestLinalgTileAndFuseSequencePass( 270 const TestLinalgTileAndFuseSequencePass &pass) 271 : PassWrapper(pass){}; 272 273 ListOption<int64_t> tileSizes{ 274 *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), 275 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 276 277 void getDependentDialects(DialectRegistry ®istry) const override { 278 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 279 scf::SCFDialect>(); 280 } 281 282 void runOnOperation() override { 283 FuncOp funcOp = getOperation(); 284 auto &blocks = funcOp.getBody().getBlocks(); 285 if (!llvm::hasSingleElement(blocks)) { 286 return; 287 } 288 SmallVector<LinalgOp, 2> linalgOps = 289 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 290 Aliases aliases; 291 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 292 OpBuilder builder(funcOp.getContext()); 293 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 294 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 295 return linalgOp.hasTensorSemantics(); 296 })) 297 loopType = LinalgTilingLoopType::Loops; 298 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 299 builder, linalgOps, dependenceGraph, 300 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 301 if (!tileAndFuseOps) 302 return signalPassFailure(); 303 if (linalgOps.back().hasTensorSemantics()) { 304 linalgOps.back().getOperation()->replaceAllUsesWith( 305 tileAndFuseOps->fusedLoops.front()); 306 } 307 for (auto op : linalgOps) 308 if (op.hasBufferSemantics()) 309 op.erase(); 310 } 311 }; 312 313 } // namespace 314 315 namespace mlir { 316 namespace test { 317 void registerTestLinalgFusionTransforms() { 318 PassRegistration<TestLinalgFusionTransformsParallelLoops>(); 319 } 320 void registerTestLinalgTensorFusionTransforms() { 321 PassRegistration<TestLinalgFusionTransformsLoops>(); 322 } 323 void registerTestLinalgTiledLoopFusionTransforms() { 324 PassRegistration<TestLinalgFusionTransformsTiledLoops>(); 325 } 326 void registerTestLinalgGreedyFusion() { 327 PassRegistration<TestLinalgGreedyFusion>(); 328 } 329 void registerTestLinalgTileAndFuseSequencePass() { 330 PassRegistration<TestLinalgTileAndFuseSequencePass>(); 331 } 332 333 } // namespace test 334 } // namespace mlir 335