1 //===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements logic for testing Linalg fusion patterns. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 14 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 15 #include "mlir/Pass/Pass.h" 16 #include "mlir/Pass/PassManager.h" 17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 18 #include "mlir/Transforms/Passes.h" 19 20 using namespace mlir; 21 using namespace mlir::linalg; 22 23 template <LinalgTilingLoopType LoopType> 24 static void fillFusionPatterns(MLIRContext *context, 25 const LinalgDependenceGraph &dependenceGraph, 26 RewritePatternSet &patterns) { 27 patterns.add<LinalgTileAndFusePattern<MatmulOp>, 28 LinalgTileAndFusePattern<ConvOp>>( 29 context, dependenceGraph, 30 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 31 LinalgFusionOptions().setIndicesToFuse({2}), 32 LinalgTransformationFilter( 33 Identifier::get("basic_fusion", context), 34 Identifier::get("after_basic_fusion", context)), 35 LinalgTransformationFilter( 36 ArrayRef<Identifier>(), 37 Identifier::get("after_basic_fusion_producer", context)), 38 LinalgTransformationFilter( 39 ArrayRef<Identifier>(), 40 Identifier::get("after_basic_fusion_original", context))); 41 42 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 43 context, dependenceGraph, 44 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 45 LinalgFusionOptions().setIndicesToFuse({0}), 46 LinalgTransformationFilter(Identifier::get("lhs_fusion", context), 47 Identifier::get("after_lhs_fusion", context)), 48 LinalgTransformationFilter( 49 ArrayRef<Identifier>(), 50 Identifier::get("after_lhs_fusion_producer", context)), 51 LinalgTransformationFilter( 52 ArrayRef<Identifier>(), 53 Identifier::get("after_lhs_fusion_original", context))); 54 55 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 56 context, dependenceGraph, 57 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 58 LinalgFusionOptions().setIndicesToFuse({2}), 59 LinalgTransformationFilter(Identifier::get("out_fusion", context), 60 Identifier::get("after_out_fusion", context)), 61 LinalgTransformationFilter( 62 ArrayRef<Identifier>(), 63 Identifier::get("after_out_fusion_producer", context)), 64 LinalgTransformationFilter( 65 ArrayRef<Identifier>(), 66 Identifier::get("after_out_fusion_original", context))); 67 68 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 69 context, dependenceGraph, 70 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 71 LinalgFusionOptions().setIndicesToFuse({1}), 72 LinalgTransformationFilter(Identifier::get("rhs_fusion", context), 73 Identifier::get("after_rhs_fusion", context)), 74 LinalgTransformationFilter( 75 ArrayRef<Identifier>(), 76 Identifier::get("after_rhs_fusion_producer", context)), 77 LinalgTransformationFilter( 78 ArrayRef<Identifier>(), 79 Identifier::get("after_rhs_fusion_original", context))); 80 81 patterns.add<LinalgTileAndFusePattern<MatmulOp>>( 82 context, dependenceGraph, 83 LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType), 84 LinalgFusionOptions().setIndicesToFuse({0, 2}), 85 LinalgTransformationFilter( 86 Identifier::get("two_operand_fusion", context), 87 Identifier::get("after_two_operand_fusion", context)), 88 LinalgTransformationFilter( 89 ArrayRef<Identifier>(), 90 Identifier::get("after_two_operand_fusion_producer", context)), 91 LinalgTransformationFilter( 92 ArrayRef<Identifier>(), 93 Identifier::get("after_two_operand_fusion_original", context))); 94 95 patterns.add<LinalgTileAndFusePattern<GenericOp>>( 96 context, dependenceGraph, 97 LinalgTilingOptions().setTileSizes({32, 64}).setLoopType(LoopType), 98 LinalgFusionOptions().setIndicesToFuse({0, 1}), 99 LinalgTransformationFilter( 100 Identifier::get("transpose_fusion", context), 101 Identifier::get("after_transpose_fusion", context)), 102 LinalgTransformationFilter( 103 ArrayRef<Identifier>(), 104 Identifier::get("after_transpose_fusion_producer", context)), 105 LinalgTransformationFilter( 106 ArrayRef<Identifier>(), 107 Identifier::get("after_transpose_fusion_original", context))); 108 } 109 110 namespace { 111 template <LinalgTilingLoopType LoopType = LinalgTilingLoopType::ParallelLoops> 112 struct TestLinalgFusionTransforms 113 : public PassWrapper<TestLinalgFusionTransforms<LoopType>, FunctionPass> { 114 TestLinalgFusionTransforms() = default; 115 TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} 116 117 void getDependentDialects(DialectRegistry ®istry) const override { 118 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 119 scf::SCFDialect, StandardOpsDialect>(); 120 } 121 122 void runOnFunction() override { 123 MLIRContext *context = &this->getContext(); 124 FuncOp funcOp = this->getFunction(); 125 RewritePatternSet fusionPatterns(context); 126 Aliases alias; 127 LinalgDependenceGraph dependenceGraph = 128 LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); 129 fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns); 130 (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns)); 131 } 132 }; 133 } // namespace 134 135 static LogicalResult fuseLinalgOpsGreedily(FuncOp f) { 136 OpBuilder b(f); 137 DenseSet<Operation *> eraseSet; 138 139 // Save original Linalg ops, we only want to make a pass over those. 140 SmallVector<LinalgOp, 8> linalgOps; 141 f.walk([&](LinalgOp op) { 142 // TODO: support multi-results. 143 if (op->getNumResults() <= 1) 144 linalgOps.push_back(op); 145 }); 146 147 // Tile and Fuse for tensors inputs (TODO: all tensor operands). 148 bool changed = false; 149 for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { 150 for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { 151 if (opOperand->get().getType().isa<MemRefType>()) { 152 // TODO: LinalgDependenceGraph should be able to update itself. 153 // The current naive and expensive reconstruction of the graph should be 154 // removed. 155 linalg::Aliases aliases; 156 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 157 if (auto info = fuseProducerOfBuffer(b, *opOperand, graph)) { 158 auto *originalOp = info->originalProducer.getOperation(); 159 eraseSet.insert(originalOp); 160 auto *originalOpInLinalgOpsVector = 161 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 162 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 163 changed = true; 164 } 165 } else { 166 assert(opOperand->get().getType().isa<RankedTensorType>()); 167 // Tile and Fuse tensor input. 168 if (opOperand->getOperandNumber() >= linalgOp.getNumInputs()) 169 continue; 170 if (auto info = fuseProducerOfTensor(b, *opOperand)) { 171 auto *originalOp = info->originalProducer.getOperation(); 172 auto *originalOpInLinalgOpsVector = 173 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 174 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 175 // Don't mark for erasure in the tensor case, let DCE handle this. 176 changed = true; 177 } 178 } 179 } 180 } 181 // The `fuseProducerOfBuffer` function performs structural checks and in 182 // particular that no covering read or write exist between the consumer and 183 // the producer. As a consequence, the only fusions that may occur preserve 184 // subsequent dependences and are guaranteed by construction to produce the 185 // whole view. We may thus erase the producer once it is fused. 186 for (auto *e : eraseSet) 187 e->erase(); 188 189 return changed ? success() : failure(); 190 } 191 192 namespace { 193 struct TestLinalgGreedyFusion 194 : public PassWrapper<TestLinalgGreedyFusion, FunctionPass> { 195 void getDependentDialects(DialectRegistry ®istry) const override { 196 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 197 scf::SCFDialect>(); 198 } 199 void runOnFunction() override { 200 MLIRContext *context = &getContext(); 201 RewritePatternSet patterns = 202 linalg::getLinalgTilingCanonicalizationPatterns(context); 203 patterns.add<AffineMinSCFCanonicalizationPattern>(context); 204 FrozenRewritePatternSet frozenPatterns(std::move(patterns)); 205 while (succeeded(fuseLinalgOpsGreedily(getFunction()))) { 206 (void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns); 207 PassManager pm(context); 208 pm.addPass(createLoopInvariantCodeMotionPass()); 209 pm.addPass(createCanonicalizerPass()); 210 pm.addPass(createCSEPass()); 211 LogicalResult res = pm.run(getFunction()->getParentOfType<ModuleOp>()); 212 if (failed(res)) 213 this->signalPassFailure(); 214 } 215 } 216 }; 217 218 /// Pass to test tile and fuse of sequence of operations. Intended only for 219 /// testing. 220 struct TestLinalgTileAndFuseSequencePass 221 : public PassWrapper<TestLinalgTileAndFuseSequencePass, FunctionPass> { 222 TestLinalgTileAndFuseSequencePass() = default; 223 TestLinalgTileAndFuseSequencePass( 224 const TestLinalgTileAndFuseSequencePass &pass){}; 225 226 ListOption<int64_t> tileSizes{ 227 *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), 228 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; 229 230 void getDependentDialects(DialectRegistry ®istry) const override { 231 registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect, 232 scf::SCFDialect>(); 233 } 234 235 void runOnFunction() override { 236 FuncOp funcOp = getOperation(); 237 auto &blocks = funcOp.getBody().getBlocks(); 238 if (!llvm::hasSingleElement(blocks)) { 239 return; 240 } 241 SmallVector<LinalgOp, 2> linalgOps = 242 llvm::to_vector<2>(blocks.front().getOps<LinalgOp>()); 243 Aliases aliases; 244 LinalgDependenceGraph dependenceGraph(aliases, linalgOps); 245 OpBuilder builder(funcOp.getContext()); 246 linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops; 247 if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) { 248 return linalgOp.hasTensorSemantics(); 249 })) 250 loopType = LinalgTilingLoopType::Loops; 251 Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps( 252 builder, linalgOps, dependenceGraph, 253 LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType)); 254 if (!tileAndFuseOps) 255 return signalPassFailure(); 256 if (linalgOps.back().hasTensorSemantics()) { 257 linalgOps.back().getOperation()->replaceAllUsesWith( 258 tileAndFuseOps->fusedLoops.front()); 259 } 260 for (auto op : linalgOps) 261 if (op.hasBufferSemantics()) 262 op.erase(); 263 } 264 }; 265 } // namespace 266 267 namespace mlir { 268 namespace test { 269 void registerTestLinalgFusionTransforms() { 270 PassRegistration<TestLinalgFusionTransforms<>> testFusionTransformsPass( 271 "test-linalg-fusion-transform-patterns", 272 "Test Linalg fusion transformation patterns by applying them greedily."); 273 } 274 void registerTestLinalgTensorFusionTransforms() { 275 PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::Loops>> 276 testTensorFusionTransformsPass( 277 "test-linalg-tensor-fusion-transform-patterns", 278 "Test Linalg on tensor fusion transformation " 279 "patterns by applying them greedily."); 280 } 281 void registerTestLinalgTiledLoopFusionTransforms() { 282 PassRegistration<TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops>> 283 testTiledLoopFusionTransformsPass( 284 "test-linalg-tiled-loop-fusion-transform-patterns", 285 "Test Linalg on tensor fusion transformation " 286 "patterns by applying them greedily."); 287 } 288 void registerTestLinalgGreedyFusion() { 289 PassRegistration<TestLinalgGreedyFusion> testFusionTransformsPass( 290 "test-linalg-greedy-fusion", 291 "Test Linalg fusion by applying a greedy test transformation."); 292 } 293 void registerTestLinalgTileAndFuseSequencePass() { 294 PassRegistration<TestLinalgTileAndFuseSequencePass> 295 testTileAndFuseSequencePass( 296 "test-linalg-tile-and-fuse", 297 "Test Linalg tiling and fusion of a sequence of Linalg operations."); 298 } 299 300 } // namespace test 301 } // namespace mlir 302