1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Func/IR/FuncOps.h" 14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 15 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 16 #include "mlir/Dialect/SCF/Transforms.h" 17 #include "mlir/Pass/Pass.h" 18 #include "mlir/Pass/PassManager.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "mlir/Transforms/Passes.h" 21 22 using namespace mlir; 23 using namespace mlir::linalg; 24 25 template <LinalgTilingLoopType LoopType> 26 static void fillFusionPatterns(MLIRContext *context, 27 const LinalgDependenceGraph &dependenceGraph, 28 RewritePatternSet &patterns) { 29 patterns.add<LinalgTileAndFusePattern<MatmulOp>, 30 LinalgTileAndFusePattern<Conv2DOp>>( 31 context, dependenceGraph, 32 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 33 LinalgFusionOptions().setIndicesToFuse({2}), 34 LinalgTransformationFilter( 35 StringAttr::get(context, "basic_fusion"), 36 StringAttr::get(context, "after_basic_fusion")), 37 LinalgTransformationFilter( 38 ArrayRef<StringAttr>(), 39 StringAttr::get(context, "after_basic_fusion_producer")), 40 LinalgTransformationFilter( 41 ArrayRef<StringAttr>(), 42 StringAttr::get(context, "after_basic_fusion_original"))); 43 44 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 45 context, dependenceGraph, 46 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 47 LinalgFusionOptions().setIndicesToFuse({0}), 48 LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"), 49 StringAttr::get(context, "after_lhs_fusion")), 50 LinalgTransformationFilter( 51 ArrayRef<StringAttr>(), 52 StringAttr::get(context, "after_lhs_fusion_producer")), 53 LinalgTransformationFilter( 54 ArrayRef<StringAttr>(), 55 StringAttr::get(context, "after_lhs_fusion_original"))); 56 57 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 58 context, dependenceGraph, 59 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 60 LinalgFusionOptions().setIndicesToFuse({2}), 61 LinalgTransformationFilter(StringAttr::get(context, "out_fusion"), 62 StringAttr::get(context, "after_out_fusion")), 63 LinalgTransformationFilter( 64 ArrayRef<StringAttr>(), 65 StringAttr::get(context, "after_out_fusion_producer")), 66 LinalgTransformationFilter( 67 ArrayRef<StringAttr>(), 68 StringAttr::get(context, "after_out_fusion_original"))); 69 70 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 71 context, dependenceGraph, 72 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 73 LinalgFusionOptions().setIndicesToFuse({1}), 74 LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"), 75 StringAttr::get(context, "after_rhs_fusion")), 76 LinalgTransformationFilter( 77 ArrayRef<StringAttr>(), 78 StringAttr::get(context, "after_rhs_fusion_producer")), 79 LinalgTransformationFilter( 80 ArrayRef<StringAttr>(), 81 StringAttr::get(context, "after_rhs_fusion_original"))); 82 83 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 84 context, dependenceGraph, 85 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 86 LinalgFusionOptions().setIndicesToFuse({0, 2}), 87 LinalgTransformationFilter( 88 StringAttr::get(context, "two_operand_fusion"), 89 StringAttr::get(context, "after_two_operand_fusion")), 90 LinalgTransformationFilter( 91 ArrayRef<StringAttr>(), 92 StringAttr::get(context, "after_two_operand_fusion_producer")), 93 LinalgTransformationFilter( 94 ArrayRef<StringAttr>(), 95 StringAttr::get(context, "after_two_operand_fusion_original"))); 96 97 patterns.add<LinalgTileAndFusePattern<GenericOp>>( 98 context, dependenceGraph, 99 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), 100 LinalgFusionOptions().setIndicesToFuse({0, 1}), 101 LinalgTransformationFilter( 102 StringAttr::get(context, "transpose_fusion"), 103 StringAttr::get(context, "after_transpose_fusion")), 104 LinalgTransformationFilter( 105 ArrayRef<StringAttr>(), 106 StringAttr::get(context, "after_transpose_fusion_producer")), 107 LinalgTransformationFilter( 108 ArrayRef<StringAttr>(), 109 StringAttr::get(context, "after_transpose_fusion_original"))); 110 } 111 112 namespace { 113 template <LinalgTilingLoopType LoopType> 114 struct TestLinalgFusionTransforms 115 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, 116 OperationPass<FuncOp>> { 117 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms) 118 119 void getDependentDialects(DialectRegistry ®istry) const override { 120 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 121 scf::SCFDialect>(); 122 } 123 TestLinalgFusionTransforms() = default; 124 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 125 126 void runOnOperation() override { 127 MLIRContext *context = &this->getContext(); 128 FuncOp funcOp = this->getOperation(); 129 RewritePatternSet fusionPatterns(context); 130 Aliases alias; 131 LinalgDependenceGraph dependenceGraph = 132 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 133 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 134 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 135 } 136 }; 137 138 struct TestLinalgFusionTransformsParallelLoops 139 : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> { 140 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 141 TestLinalgFusionTransformsParallelLoops) 142 143 StringRef getArgument() const final { 144 return "test-linalg-fusion-transform-patterns"; 145 } 146 StringRef getDescription() const final { 147 return "Test Linalg fusion transformation patterns by applying them " 148 "greedily."; 149 } 150 }; 151 152 struct TestLinalgFusionTransformsLoops 153 : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> { 154 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops) 155 156 StringRef getArgument() const final { 157 return "test-linalg-tensor-fusion-transform-patterns"; 158 } 159 StringRef getDescription() const final { 160 return "Test Linalg on tensor fusion transformation " 161 "patterns by applying them greedily."; 162 } 163 }; 164 165 struct TestLinalgFusionTransformsTiledLoops 166 : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> { 167 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 168 TestLinalgFusionTransformsTiledLoops) 169 170 StringRef getArgument() const final { 171 return "test-linalg-tiled-loop-fusion-transform-patterns"; 172 } 173 StringRef getDescription() const final { 174 return "Test Linalg on tensor fusion transformation " 175 "patterns by applying them greedily."; 176 } 177 }; 178 } // namespace 179 180 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { 181 OpBuilder b(f); 182 DenseSet<Operation *> eraseSet; 183 184 // Save original Linalg ops, we only want to make a pass over those. 185 SmallVector<LinalgOp, 8> linalgOps; 186 f.walk([&](LinalgOp op) { 187 // TODO: support multi-results. 188 if (op->getNumResults() <= 1) 189 linalgOps.push_back(op); 190 }); 191 192 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 193 bool changed = false; 194 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 195 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 196 if (opOperand->get().getType().isa<MemRefType>()) { 197 // TODO: LinalgDependenceGraph should be able to update itself. 198 // The current naive and expensive reconstruction of the graph should be 199 // removed. 200 linalg::Aliases aliases; 201 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 202 auto info = fuseProducerOfBuffer(b, *opOperand, graph); 203 if (failed(info)) 204 continue; 205 auto *originalOp = info->originalProducer.getOperation(); 206 eraseSet.insert(originalOp); 207 auto *originalOpInLinalgOpsVector = 208 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 209 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 210 changed = true; 211 } else if (opOperand->get().getType().isa<RankedTensorType>()) { 212 // Tile and Fuse tensor input. 213 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 214 continue; 215 auto info = fuseProducerOfTensor(b, *opOperand); 216 if (failed(info)) 217 continue; 218 auto *originalOp = info->originalProducer.getOperation(); 219 auto *originalOpInLinalgOpsVector = 220 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 221 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 222 // Don't mark for erasure in the tensor case, let DCE handle this. 223 changed = true; 224 } 225 } 226 } 227 // The `fuseProducerOfBuffer` function performs structural checks and in 228 // particular that no covering read or write exist between the consumer and 229 // the producer. As a consequence, the only fusions that may occur preserve 230 // subsequent dependences and are guaranteed by construction to produce the 231 // whole view. We may thus erase the producer once it is fused. 232 for (auto *e : eraseSet) 233 e->erase(); 234 235 return changed ? success() : failure(); 236 } 237 238 namespace { 239 struct TestLinalgGreedyFusion 240 : public PassWrapper<TestLinalgGreedyFusion, OperationPass<FuncOp>> { 241 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion) 242 243 void getDependentDialects(DialectRegistry ®istry) const override { 244 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 245 scf::SCFDialect>(); 246 } 247 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 248 StringRef getDescription() const final { 249 return "Test Linalg fusion by applying a greedy test transformation."; 250 } 251 void runOnOperation() override { 252 MLIRContext *context = &getContext(); 253 RewritePatternSet patterns = 254 linalg::getLinalgTilingCanonicalizationPatterns(context); 255 patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 256 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 257 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 258 do { 259 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); 260 PassManager pm(context); 261 pm.addPass(createLoopInvariantCodeMotionPass()); 262 pm.addPass(createCanonicalizerPass()); 263 pm.addPass(createCSEPass()); 264 LogicalResult res = pm.run(getOperation()->getParentOfType<ModuleOp>()); 265 if (failed(res)) 266 this->signalPassFailure(); 267 } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); 268 } 269 }; 270 271 /// Pass to test tile and fuse of sequence of operations. Intended only for 272 /// testing. 273 struct TestLinalgTileAndFuseSequencePass 274 : public PassWrapper<TestLinalgTileAndFuseSequencePass, 275 OperationPass<FuncOp>> { 276 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 277 TestLinalgTileAndFuseSequencePass) 278 279 StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } 280 StringRef getDescription() const final { 281 return "Test Linalg tiling and fusion of a sequence of Linalg operations."; 282 } 283 TestLinalgTileAndFuseSequencePass() = default; 284 TestLinalgTileAndFuseSequencePass( 285 const TestLinalgTileAndFuseSequencePass &pass) 286 : PassWrapper(pass){}; 287 288 ListOption<int64_t> tileSizes{*this, "tile-sizes", 289 llvm::cl::desc("Tile sizes to use for ops"), 290 llvm::cl::ZeroOrMore}; 291 292 void getDependentDialects(DialectRegistry ®istry) const override { 293 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 294 scf::SCFDialect>(); 295 } 296 297 void runOnOperation() override { 298 FuncOp funcOp = getOperation(); 299 auto &blocks = funcOp.getBody().getBlocks(); 300 if (!llvm::hasSingleElement(blocks)) { 301 return; 302 } 303 SmallVector<LinalgOp, 2> linalgOps = 304 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 305 Aliases aliases; 306 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 307 OpBuilder builder(funcOp.getContext()); 308 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 309 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 310 return linalgOp.hasTensorSemantics(); 311 })) 312 loopType = LinalgTilingLoopType::Loops; 313 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 314 builder, linalgOps, dependenceGraph, 315 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 316 if (!tileAndFuseOps) 317 return signalPassFailure(); 318 if (linalgOps.back().hasTensorSemantics()) { 319 linalgOps.back().getOperation()->replaceAllUsesWith( 320 tileAndFuseOps->fusedLoops.front()); 321 } 322 for (auto op : linalgOps) 323 if (op.hasBufferSemantics()) 324 op.erase(); 325 } 326 }; 327 328 } // namespace 329 330 namespace mlir { 331 namespace test { 332 void registerTestLinalgFusionTransforms() { 333 PassRegistration<TestLinalgFusionTransformsParallelLoops>(); 334 } 335 void registerTestLinalgTensorFusionTransforms() { 336 PassRegistration<TestLinalgFusionTransformsLoops>(); 337 } 338 void registerTestLinalgTiledLoopFusionTransforms() { 339 PassRegistration<TestLinalgFusionTransformsTiledLoops>(); 340 } 341 void registerTestLinalgGreedyFusion() { 342 PassRegistration<TestLinalgGreedyFusion>(); 343 } 344 void registerTestLinalgTileAndFuseSequencePass() { 345 PassRegistration<TestLinalgTileAndFuseSequencePass>(); 346 } 347 348 } // namespace test 349 } // namespace mlir 350