1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Func/IR/FuncOps.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 17 #include "mlir/Dialect/SCF/Transforms.h" 18 #include "mlir/Pass/Pass.h" 19 #include "mlir/Pass/PassManager.h" 20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 21 #include "mlir/Transforms/Passes.h" 22 23 using namespace mlir; 24 using namespace mlir::linalg; 25 26 template <LinalgTilingLoopType LoopType> 27 static void fillFusionPatterns(MLIRContext *context, 28 const LinalgDependenceGraph &dependenceGraph, 29 RewritePatternSet &patterns) { 30 patterns.add<LinalgTileAndFusePattern<MatmulOp>, 31 LinalgTileAndFusePattern<Conv2DOp>>( 32 context, dependenceGraph, 33 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 34 LinalgFusionOptions().setIndicesToFuse({2}), 35 LinalgTransformationFilter( 36 StringAttr::get(context, "basic_fusion"), 37 StringAttr::get(context, "after_basic_fusion")), 38 LinalgTransformationFilter( 39 ArrayRef<StringAttr>(), 40 StringAttr::get(context, "after_basic_fusion_producer")), 41 LinalgTransformationFilter( 42 ArrayRef<StringAttr>(), 43 StringAttr::get(context, "after_basic_fusion_original"))); 44 45 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 46 context, dependenceGraph, 47 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 48 LinalgFusionOptions().setIndicesToFuse({0}), 49 LinalgTransformationFilter(StringAttr::get(context, "lhs_fusion"), 50 StringAttr::get(context, "after_lhs_fusion")), 51 LinalgTransformationFilter( 52 ArrayRef<StringAttr>(), 53 StringAttr::get(context, "after_lhs_fusion_producer")), 54 LinalgTransformationFilter( 55 ArrayRef<StringAttr>(), 56 StringAttr::get(context, "after_lhs_fusion_original"))); 57 58 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 59 context, dependenceGraph, 60 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 61 LinalgFusionOptions().setIndicesToFuse({2}), 62 LinalgTransformationFilter(StringAttr::get(context, "out_fusion"), 63 StringAttr::get(context, "after_out_fusion")), 64 LinalgTransformationFilter( 65 ArrayRef<StringAttr>(), 66 StringAttr::get(context, "after_out_fusion_producer")), 67 LinalgTransformationFilter( 68 ArrayRef<StringAttr>(), 69 StringAttr::get(context, "after_out_fusion_original"))); 70 71 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 72 context, dependenceGraph, 73 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 74 LinalgFusionOptions().setIndicesToFuse({1}), 75 LinalgTransformationFilter(StringAttr::get(context, "rhs_fusion"), 76 StringAttr::get(context, "after_rhs_fusion")), 77 LinalgTransformationFilter( 78 ArrayRef<StringAttr>(), 79 StringAttr::get(context, "after_rhs_fusion_producer")), 80 LinalgTransformationFilter( 81 ArrayRef<StringAttr>(), 82 StringAttr::get(context, "after_rhs_fusion_original"))); 83 84 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 85 context, dependenceGraph, 86 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 87 LinalgFusionOptions().setIndicesToFuse({0, 2}), 88 LinalgTransformationFilter( 89 StringAttr::get(context, "two_operand_fusion"), 90 StringAttr::get(context, "after_two_operand_fusion")), 91 LinalgTransformationFilter( 92 ArrayRef<StringAttr>(), 93 StringAttr::get(context, "after_two_operand_fusion_producer")), 94 LinalgTransformationFilter( 95 ArrayRef<StringAttr>(), 96 StringAttr::get(context, "after_two_operand_fusion_original"))); 97 98 patterns.add<LinalgTileAndFusePattern<GenericOp>>( 99 context, dependenceGraph, 100 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), 101 LinalgFusionOptions().setIndicesToFuse({0, 1}), 102 LinalgTransformationFilter( 103 StringAttr::get(context, "transpose_fusion"), 104 StringAttr::get(context, "after_transpose_fusion")), 105 LinalgTransformationFilter( 106 ArrayRef<StringAttr>(), 107 StringAttr::get(context, "after_transpose_fusion_producer")), 108 LinalgTransformationFilter( 109 ArrayRef<StringAttr>(), 110 StringAttr::get(context, "after_transpose_fusion_original"))); 111 } 112 113 namespace { 114 template <LinalgTilingLoopType LoopType> 115 struct TestLinalgFusionTransforms 116 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, 117 OperationPass<func::FuncOp>> { 118 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms) 119 120 void getDependentDialects(DialectRegistry ®istry) const override { 121 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 122 scf::SCFDialect>(); 123 } 124 TestLinalgFusionTransforms() = default; 125 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 126 127 void runOnOperation() override { 128 MLIRContext *context = &this->getContext(); 129 func::FuncOp funcOp = this->getOperation(); 130 RewritePatternSet fusionPatterns(context); 131 Aliases alias; 132 LinalgDependenceGraph dependenceGraph = 133 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 134 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 135 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 136 } 137 }; 138 139 struct TestLinalgFusionTransformsParallelLoops 140 : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> { 141 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 142 TestLinalgFusionTransformsParallelLoops) 143 144 StringRef getArgument() const final { 145 return "test-linalg-fusion-transform-patterns"; 146 } 147 StringRef getDescription() const final { 148 return "Test Linalg fusion transformation patterns by applying them " 149 "greedily."; 150 } 151 }; 152 153 struct TestLinalgFusionTransformsLoops 154 : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> { 155 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops) 156 157 StringRef getArgument() const final { 158 return "test-linalg-tensor-fusion-transform-patterns"; 159 } 160 StringRef getDescription() const final { 161 return "Test Linalg on tensor fusion transformation " 162 "patterns by applying them greedily."; 163 } 164 }; 165 166 struct TestLinalgFusionTransformsTiledLoops 167 : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> { 168 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 169 TestLinalgFusionTransformsTiledLoops) 170 171 StringRef getArgument() const final { 172 return "test-linalg-tiled-loop-fusion-transform-patterns"; 173 } 174 StringRef getDescription() const final { 175 return "Test Linalg on tensor fusion transformation " 176 "patterns by applying them greedily."; 177 } 178 }; 179 } // namespace 180 181 static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) { 182 OpBuilder b(f); 183 DenseSet<Operation *> eraseSet; 184 185 // Save original Linalg ops, we only want to make a pass over those. 186 SmallVector<LinalgOp, 8> linalgOps; 187 f.walk([&](LinalgOp op) { 188 // TODO: support multi-results. 189 if (op->getNumResults() <= 1) 190 linalgOps.push_back(op); 191 }); 192 193 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 194 bool changed = false; 195 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 196 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 197 if (opOperand->get().getType().isa<MemRefType>()) { 198 // TODO: LinalgDependenceGraph should be able to update itself. 199 // The current naive and expensive reconstruction of the graph should be 200 // removed. 201 linalg::Aliases aliases; 202 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 203 auto info = fuseProducerOfBuffer(b, *opOperand, graph); 204 if (failed(info)) 205 continue; 206 auto *originalOp = info->originalProducer.getOperation(); 207 eraseSet.insert(originalOp); 208 auto *originalOpInLinalgOpsVector = 209 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 210 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 211 changed = true; 212 } else if (opOperand->get().getType().isa<RankedTensorType>()) { 213 // Tile and Fuse tensor input. 214 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 215 continue; 216 auto info = fuseProducerOfTensor(b, *opOperand); 217 if (failed(info)) 218 continue; 219 auto *originalOp = info->originalProducer.getOperation(); 220 auto *originalOpInLinalgOpsVector = 221 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 222 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 223 // Don't mark for erasure in the tensor case, let DCE handle this. 224 changed = true; 225 } 226 } 227 } 228 // The `fuseProducerOfBuffer` function performs structural checks and in 229 // particular that no covering read or write exist between the consumer and 230 // the producer. As a consequence, the only fusions that may occur preserve 231 // subsequent dependences and are guaranteed by construction to produce the 232 // whole view. We may thus erase the producer once it is fused. 233 for (auto *e : eraseSet) 234 e->erase(); 235 236 return changed ? success() : failure(); 237 } 238 239 namespace { 240 struct TestLinalgGreedyFusion 241 : public PassWrapper<TestLinalgGreedyFusion, OperationPass<func::FuncOp>> { 242 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgGreedyFusion) 243 244 void getDependentDialects(DialectRegistry ®istry) const override { 245 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 246 scf::SCFDialect>(); 247 } 248 StringRef getArgument() const final { return "test-linalg-greedy-fusion"; } 249 StringRef getDescription() const final { 250 return "Test Linalg fusion by applying a greedy test transformation."; 251 } 252 void runOnOperation() override { 253 MLIRContext *context = &getContext(); 254 RewritePatternSet patterns = 255 linalg::getLinalgTilingCanonicalizationPatterns(context); 256 patterns.add<ExtractSliceOfPadTensorSwapPattern>(context); 257 scf::populateSCFForLoopCanonicalizationPatterns(patterns); 258 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 259 OpPassManager pm(func::FuncOp::getOperationName()); 260 pm.addPass(createLoopInvariantCodeMotionPass()); 261 pm.addPass(createCanonicalizerPass()); 262 pm.addPass(createCSEPass()); 263 do { 264 (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); 265 if (failed(runPipeline(pm, getOperation()))) 266 this->signalPassFailure(); 267 } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); 268 } 269 }; 270 271 /// Pass to test tile and fuse of sequence of operations. Intended only for 272 /// testing. 273 struct TestLinalgTileAndFuseSequencePass 274 : public PassWrapper<TestLinalgTileAndFuseSequencePass, 275 OperationPass<func::FuncOp>> { 276 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( 277 TestLinalgTileAndFuseSequencePass) 278 279 StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; } 280 StringRef getDescription() const final { 281 return "Test Linalg tiling and fusion of a sequence of Linalg operations."; 282 } 283 TestLinalgTileAndFuseSequencePass() = default; 284 TestLinalgTileAndFuseSequencePass( 285 const TestLinalgTileAndFuseSequencePass &pass) 286 : PassWrapper(pass){}; 287 288 ListOption<int64_t> tileSizes{*this, "tile-sizes", 289 llvm::cl::desc("Tile sizes to use for ops"), 290 llvm::cl::ZeroOrMore}; 291 292 void getDependentDialects(DialectRegistry ®istry) const override { 293 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 294 scf::SCFDialect>(); 295 } 296 297 void runOnOperation() override { 298 func::FuncOp funcOp = getOperation(); 299 auto &blocks = funcOp.getBody().getBlocks(); 300 if (!llvm::hasSingleElement(blocks)) { 301 return; 302 } 303 SmallVector<LinalgOp, 2> linalgOps = 304 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 305 Aliases aliases; 306 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 307 OpBuilder builder(funcOp.getContext()); 308 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 309 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 310 return linalgOp.hasTensorSemantics(); 311 })) 312 loopType = LinalgTilingLoopType::Loops; 313 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 314 builder, linalgOps, dependenceGraph, 315 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 316 if (!tileAndFuseOps) 317 return signalPassFailure(); 318 if (linalgOps.back().hasTensorSemantics()) { 319 linalgOps.back().getOperation()->replaceAllUsesWith( 320 tileAndFuseOps->fusedLoops.front()); 321 } 322 for (auto op : linalgOps) 323 if (op.hasBufferSemantics()) 324 op.erase(); 325 } 326 }; 327 328 } // namespace 329 330 namespace mlir { 331 namespace test { 332 void registerTestLinalgFusionTransforms() { 333 PassRegistration<TestLinalgFusionTransformsParallelLoops>(); 334 } 335 void registerTestLinalgTensorFusionTransforms() { 336 PassRegistration<TestLinalgFusionTransformsLoops>(); 337 } 338 void registerTestLinalgTiledLoopFusionTransforms() { 339 PassRegistration<TestLinalgFusionTransformsTiledLoops>(); 340 } 341 void registerTestLinalgGreedyFusion() { 342 PassRegistration<TestLinalgGreedyFusion>(); 343 } 344 void registerTestLinalgTileAndFuseSequencePass() { 345 PassRegistration<TestLinalgTileAndFuseSequencePass>(); 346 } 347 348 } // namespace test 349 } // namespace mlir 350