//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic for testing Linalg transformations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/HoistPadding.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {
struct TestLinalgTransforms
    : public PassWrapper<TestLinalgTransforms, FunctionPass> {
  TestLinalgTransforms() = default;
  TestLinalgTransforms(const TestLinalgTransforms &pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    // clang-format off
    registry.insert<AffineDialect,
                    memref::MemRefDialect,
                    scf::SCFDialect,
                    StandardOpsDialect,
                    vector::VectorDialect,
                    gpu::GPUDialect>();
    // clang-format on
  }
  StringRef getArgument() const final {
    return "test-linalg-transform-patterns";
  }
  StringRef getDescription() const final {
    return "Test Linalg transformation patterns by applying them greedily.";
  }

  void runOnFunction() override;

  Option<bool> testPatterns{*this, "test-patterns",
                            llvm::cl::desc("Test a mixed set of patterns"),
                            llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns1dTiling{
      *this, "test-matmul-to-vector-patterns-tile-1d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "1-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testMatmulToVectorPatterns2dTiling{
      *this, "test-matmul-to-vector-patterns-tile-2d",
      llvm::cl::desc(
          "Test a fused pass that applies patterns from matmul to vectors via "
          "2-d tiling"),
      llvm::cl::init(false)};
  Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                    llvm::cl::desc("Test promotion options"),
                                    llvm::cl::init(false)};
  Option<bool> testTileAndDistributionOptions{
      *this, "test-tile-and-distribute-options",
      llvm::cl::desc("Test tile and distribute options"),
      llvm::cl::init(false)};
  Option<bool> testVectorTransferForwardingPatterns{
      *this, "test-vector-transfer-forwarding-patterns",
      llvm::cl::desc(
          "Test a fused pass that forwards linalg.copy to vector.transfer"),
      llvm::cl::init(false)};
  Option<bool> testGenericToVectorPattern{
      *this, "test-linalg-to-vector-patterns",
      llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
                     "in vector.contract form"),
      llvm::cl::init(false)};
  Option<bool> testTilePattern{*this, "test-tile-pattern",
                               llvm::cl::desc("Test tile pattern"),
                               llvm::cl::init(false)};
  Option<bool> testTileScalarizeDynamicDims{
      *this, "test-tile-scalarize-dynamic-dims",
      llvm::cl::desc("Test tiling of dynamic dims by 1"),
      llvm::cl::init(false)};
  Option<int> testHoistPadding{*this, "test-hoist-padding",
                               llvm::cl::desc("Test hoist padding"),
                               llvm::cl::init(0)};
  Option<bool> testPadPattern{*this, "test-pad-pattern",
                              llvm::cl::desc("Test pad pattern"),
                              llvm::cl::init(false)};
  Option<bool> testTransformPadTensor{
      *this, "test-transform-pad-tensor",
      llvm::cl::desc("Test transform pad tensor by copying with generic ops"),
      llvm::cl::init(false)};
  Option<bool> testGeneralizePadTensor{
      *this, "test-generalize-pad-tensor",
      llvm::cl::desc("Test generalization of pad tensor ops"),
      llvm::cl::init(false)};
  Option<bool> testSwapSubTensorPadTensor{
      *this, "test-swap-subtensor-padtensor",
      llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into "
                     "pad_tensor(subtensor)"),
      llvm::cl::init(false)};
  ListOption<int64_t> paddedOperands{
      *this, "padded-operands",
      llvm::cl::desc("Operands to pad when test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> nofoldOperands{
      *this, "nofold-operands",
      llvm::cl::desc("Operands to set nofold when test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> packPaddings{
      *this, "pack-paddings",
      llvm::cl::desc("Operand packing flags when test-pad-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> hoistPaddings{
      *this, "hoist-paddings",
      llvm::cl::desc("Operand hoisting depths when test-pad-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> peeledLoops{
      *this, "peeled-loops",
      llvm::cl::desc("Loops to be peeled when test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<int64_t> tileSizes{
      *this, "tile-sizes",
      llvm::cl::desc("Linalg tile sizes for test-tile-pattern"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
  ListOption<unsigned> testInterchangePattern{
      *this, "test-interchange-pattern", llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::desc("Test the interchange pattern.")};
  ListOption<unsigned> testTiledLoopPeeling{
      *this, "test-tiled-loop-peeling",
      llvm::cl::desc("Test peeling of linalg.tiled_loop ops"),
      llvm::cl::OneOrMore, llvm::cl::MiscFlags::CommaSeparated};
  Option<bool> skipPartial{
      *this, "skip-partial",
      llvm::cl::desc("Skip loops inside partial iterations during peeling"),
      llvm::cl::init(false)};
  Option<std::string> loopType{
      *this, "loop-type",
      llvm::cl::desc("Specify the type of loops to generate: for, parallel or "
                     "tiled_loop"),
      llvm::cl::init("for")};
  Option<bool> testDecomposeConvolutionPattern{
      *this, "test-decompose-convolution-patterns",
      llvm::cl::desc("Test a set of patterns to rewrite high-D convolution ops "
                     "into low-D ones"),
      llvm::cl::init(false)};
};
} // end anonymous namespace

static void applyPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);

  //===--------------------------------------------------------------------===//
  // Linalg tiling patterns.
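  // Each pattern below matches ops carrying the first marker and re-tags the
  // tiled op with the second, so a matmul tagged "MEM" is tiled in four
  // stages down to register-sized tiles: MEM -> L3 -> L2 -> L1 -> REG.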
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
      LinalgTransformationFilter(Identifier::get("MEM", ctx),
                                 Identifier::get("L3", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
      LinalgTransformationFilter(Identifier::get("L3", ctx),
                                 Identifier::get("L2", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L2", ctx),
                                 Identifier::get("L1", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
      LinalgTransformationFilter(Identifier::get("L1", ctx),
                                 Identifier::get("REG", ctx)));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("L1", ctx)));

  patterns.add<LinalgTilingPattern<DotOp>>(
      ctx, LinalgTilingOptions().setTileSizes(8000),
      LinalgTransformationFilter(
          ArrayRef<Identifier>{Identifier::get("MEM", ctx),
                               Identifier::get("L3", ctx),
                               Identifier::get("L2", ctx)},
          Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg tiling and permutation patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({2000, 3000, 4000})
          .setInterchange({1, 2, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L2__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({200, 300, 400})
          .setInterchange({1, 0, 2}),
      LinalgTransformationFilter(Identifier::get("L2__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgTransformationFilter(Identifier::get("L1__with_perm__", ctx),
                                 Identifier::get("REG__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatvecOp>>(
      ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgTransformationFilter(Identifier::get("__with_perm__", ctx),
                                 Identifier::get("L1__with_perm__", ctx)));

  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({16, 8, 4})
          .setInterchange({1, 2, 0})
          .setLoopType(LinalgTilingLoopType::ParallelLoops),
      LinalgTransformationFilter(
          Identifier::get("par__with_perm__", ctx),
          Identifier::get("after_par__with_perm__", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg to loops patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgLoweringPattern<DotOp>>(
      ctx,
      /*loweringType=*/LinalgLoweringType::Loops,
      LinalgTransformationFilter(Identifier::get("REG", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg distribution patterns.
  //===--------------------------------------------------------------------===//
  LinalgLoopDistributionOptions distributionOptions;

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter(Identifier::get("VECTORIZE", ctx))
               .addOpFilter<MatmulOp, FillOp, CopyOp, GenericOp>());

  //===--------------------------------------------------------------------===//
  // Linalg generic interchange pattern.
  //===--------------------------------------------------------------------===//
  patterns.add<GenericOpInterchangePattern>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("PERMUTED", ctx)));

  //===--------------------------------------------------------------------===//
  // Linalg subview operands promotion.
  //===--------------------------------------------------------------------===//
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(Identifier::get("_promote_views_", ctx),
                                 Identifier::get("_views_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffersByDefault(true),
      LinalgTransformationFilter(
          Identifier::get("_promote_first_view_", ctx),
          Identifier::get("_first_view_promoted_", ctx)));
  patterns.add<LinalgPromotionPattern<FillOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({1})
          .setUseFullTileBuffers({false, true})
          .setAlignment(32),
      LinalgTransformationFilter(
          Identifier::get("_promote_views_aligned_", ctx),
          Identifier::get("_views_aligned_promoted_", ctx)));

  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the marker.
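  // The filter-based patterns above re-tag the ops they rewrite with their
  // output marker (when one is set); stripping the attribute here keeps the
  // resulting IR free of test-only markers.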
  funcOp.walk([](LinalgOp op) {
    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}

static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<RewritePatternSet> &patternsVector) {
  MLIRContext *ctx = funcOp.getContext();
  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
               ctx,
               LinalgTilingOptions()
                   .setTileSizes({8, 12, 16})
                   .setInterchange({1, 0, 2}),
               LinalgTransformationFilter(Identifier::get(startMarker, ctx),
                                          Identifier::get("L1", ctx))));

  patternsVector.emplace_back(
      ctx,
      std::make_unique<LinalgPromotionPattern<MatmulOp>>(
          ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
          LinalgTransformationFilter(Identifier::get("L1", ctx),
                                     Identifier::get("VEC", ctx))));

  patternsVector.emplace_back(
      ctx, std::make_unique<LinalgVectorizationPattern>(
               MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
               LinalgTransformationFilter(Identifier::get("VEC", ctx))));
  patternsVector.back().add<LinalgVectorizationPattern>(
      ctx, LinalgTransformationFilter().addFilter(
               [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
}

//===----------------------------------------------------------------------===//
// Test promotion callbacks
//===----------------------------------------------------------------------===//

// Allocation callback.
static Optional<Value> allocCallBackFn(OpBuilder &b, memref::SubViewOp subView,
                                       ArrayRef<Value> boundingSubViewSize,
                                       DataLayout &layout) {
  SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
  return b
      .create<memref::AllocOp>(
          subView.getLoc(),
          MemRefType::get(shape, subView.getType().getElementType(),
                          /*affineMapComposition =*/{}, 3),
          boundingSubViewSize)
      .getResult();
}

// Deallocation callback.
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
  b.create<memref::DeallocOp>(buffer.getLoc(), buffer);
  return success();
}

// Copy-in/copy-out callback.
static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
                                    bool isOutput) {
  auto floatType = src.getType().cast<MemRefType>().getElementType();
  if (!floatType.isa<FloatType>())
    return failure();
  if (!isOutput) {
    Value cst = b.create<arith::ConstantOp>(src.getLoc(),
                                            FloatAttr::get(floatType, 42.0));
    b.create<FillOp>(src.getLoc(), cst, dst);
  }
  b.create<CopyOp>(src.getLoc(), src, dst);
  return success();
}

static void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                          RewritePatternSet &patterns) {
  patterns.add<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
      LinalgTransformationFilter(Identifier::get("START", ctx),
                                 Identifier::get("PROMOTE", ctx)));
  patterns.add<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0, 2})
          .setUseFullTileBuffers({false, false})
          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
          .setCopyInOutFns(
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, false);
              },
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, true);
              }),
      LinalgTransformationFilter(Identifier::get("PROMOTE", ctx)));
}

template <typename IdOp, typename NProcsOp>
static SmallVector<ProcInfo, 2>
getGpuProcIds(OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges) {
  size_t count = std::min<size_t>(3, parallelLoopRanges.size());
  SmallVector<ProcInfo, 2> procInfo(count);
  const char *xyz[] = {"x", "y", "z"};
  Type indexType = b.getIndexType();
  for (unsigned i = 0; i < count; ++i) {
    procInfo[count - 1 - i] = {
        b.create<IdOp>(loc, indexType, b.getStringAttr(xyz[i])),
        b.create<NProcsOp>(loc, indexType, b.getStringAttr(xyz[i]))};
  }
  return procInfo;
}

static void fillTileAndDistributePatterns(MLIRContext *context,
                                          RewritePatternSet &patterns) {
  {
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsEqNumIters);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute1", context),
            Identifier::get("after_distribute1", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsGeNiters;
    cyclicNprocsGeNiters.distributionMethod.resize(
        2, DistributionMethod::CyclicNumProcsGeNumIters);
    cyclicNprocsGeNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsGeNiters),
        LinalgTransformationFilter(
            Identifier::get("distribute2", context),
            Identifier::get("after_distribute2", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsDefault;
    cyclicNprocsDefault.distributionMethod.resize(2,
                                                  DistributionMethod::Cyclic);
    cyclicNprocsDefault.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsDefault),
        LinalgTransformationFilter(
            Identifier::get("distribute3", context),
            Identifier::get("after_distribute3", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed1;
    cyclicNprocsMixed1.distributionMethod = {
        DistributionMethod::CyclicNumProcsEqNumIters,
        DistributionMethod::CyclicNumProcsGeNumIters};
    cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed1),
        LinalgTransformationFilter(
            Identifier::get("distribute4", context),
            Identifier::get("after_distribute4", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed2;
    cyclicNprocsMixed2.distributionMethod = {
        DistributionMethod::CyclicNumProcsGeNumIters,
        DistributionMethod::Cyclic};
    cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed2),
        LinalgTransformationFilter(
            Identifier::get("distribute5", context),
            Identifier::get("after_distribute5", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsMixed3;
    cyclicNprocsMixed3.distributionMethod = {
        DistributionMethod::Cyclic,
        DistributionMethod::CyclicNumProcsEqNumIters};
    cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;

    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::ParallelLoops)
            .setDistributionOptions(cyclicNprocsMixed3),
        LinalgTransformationFilter(
            Identifier::get("distribute6", context),
            Identifier::get("after_distribute6", context)));
  }

  {
    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
    cyclicNprocsEqNiters.distributionMethod.resize(2,
                                                   DistributionMethod::Cyclic);
    cyclicNprocsEqNiters.procInfo =
        getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
    patterns.add<LinalgTilingPattern<MatmulOp>>(
        context,
        LinalgTilingOptions()
            .setTileSizes({8, 8, 4})
            .setLoopType(LinalgTilingLoopType::Loops)
            .setDistributionOptions(cyclicNprocsEqNiters),
        LinalgTransformationFilter(
            Identifier::get("tensors_distribute1", context),
            Identifier::get("tensors_after_distribute1", context)));
  }
}

static void
applyMatmulToVectorPatterns(FuncOp funcOp,
                            bool testMatmulToVectorPatterns1dTiling,
                            bool testMatmulToVectorPatterns2dTiling) {
  MLIRContext *ctx = funcOp.getContext();
  SmallVector<RewritePatternSet, 4> stage1Patterns;
  if (testMatmulToVectorPatterns1dTiling) {
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
                                          stage1Patterns);
  } else if (testMatmulToVectorPatterns2dTiling) {
    stage1Patterns.emplace_back(
        ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
                 ctx,
                 LinalgTilingOptions()
                     .setTileSizes({768, 264, 768})
                     .setInterchange({1, 2, 0}),
                 LinalgTransformationFilter(Identifier::get("START", ctx),
                                            Identifier::get("L2", ctx))));
    fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
                                          stage1Patterns);
  }
  {
    // Canonicalization patterns.
    RewritePatternSet canonicalizationPatterns(funcOp.getContext());
    vector::populateVectorTransferPermutationMapLoweringPatterns(
        canonicalizationPatterns);
    vector::populateVectorReductionToContractPatterns(canonicalizationPatterns);
    stage1Patterns.push_back(std::move(canonicalizationPatterns));
  }
  SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
  llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
  FrozenRewritePatternSet stage2Patterns =
      getLinalgTilingCanonicalizationPatterns(ctx);
  (void)applyStagedPatterns(funcOp, frozenStage1Patterns,
                            std::move(stage2Patterns));
}

static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
  RewritePatternSet forwardPattern(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
  forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}

static void applyLinalgToVectorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
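  // Vectorize contractions, fills, copies, and generic ops; pad_tensor and
  // convolution ops are handled by the dedicated populate* helpers below.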
  patterns.add<LinalgVectorizationPattern>(
      funcOp.getContext(),
      LinalgTransformationFilter()
          .addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
  populatePadTensorOpVectorizationPatterns(patterns);
  populateConvolutionVectorizationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyDecomposeConvolutionPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  populateDecomposeConvolutionPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyPadTensorToGenericPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<PadTensorOpTransformationPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyGeneralizePadTensorPatterns(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<GeneralizePadTensorOpPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}

// For now, just assume the neutral element is the zero of the element type.
// In the future, it should be derived from both the type and the op.
static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
  auto t = getElementTypeOrSelf(op.get());
  return b.create<arith::ConstantOp>(op.getOwner()->getLoc(), t,
                                     b.getZeroAttr(t));
}

static void applyTilePattern(FuncOp funcOp, std::string loopType,
                             ArrayRef<int64_t> tileSizes,
                             ArrayRef<int64_t> paddedOperands,
                             ArrayRef<int64_t> nofoldOperands,
                             ArrayRef<int64_t> peeledLoops,
                             bool scalarizeDynamicDims) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet tilingPattern(context);
  LinalgTilingLoopType type =
      llvm::StringSwitch<LinalgTilingLoopType>(loopType)
          .Case("for", LinalgTilingLoopType::Loops)
          .Case("affine", LinalgTilingLoopType::AffineLoops)
          .Case("parallel", LinalgTilingLoopType::ParallelLoops)
          .Case("tiled_loop", LinalgTilingLoopType::TiledLoops);
  auto linalgTilingOptions = linalg::LinalgTilingOptions()
                                 .setPeeledLoops(peeledLoops)
                                 .setLoopType(type);
  if (scalarizeDynamicDims) {
    linalgTilingOptions.scalarizeDynamicDims();
    assert(tileSizes.empty() &&
           "tileSizes and scalarizeDynamicDims are mutually exclusive");
  } else {
    linalgTilingOptions.setTileSizes(tileSizes);
  }
  if (!paddedOperands.empty()) {
    auto paddingFunc = [&](OpBuilder &b,
                           OpOperand &opOperand) -> FailureOr<Value> {
      if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0)
        return failure();
      return getNeutralOfLinalgOp(b, opOperand);
    };
    auto nofoldFunc = [&](OpOperand &opOperand) {
      if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0)
        return true;
      return false;
    };
    linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc);
    linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc);
  }
  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
      context, linalgTilingOptions,
      linalg::LinalgTransformationFilter(Identifier::get("tile", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
}

static void applyPadPattern(FuncOp funcOp, ArrayRef<int64_t> packPaddings,
                            ArrayRef<int64_t> hoistPaddings) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet padPattern(context);
  auto linalgPaddingOptions = linalg::LinalgPaddingOptions();
  auto packFunc = [&](OpOperand &opOperand) {
    return opOperand.getOperandNumber() < packPaddings.size()
               ? packPaddings[opOperand.getOperandNumber()]
               : false;
  };
  auto hoistingFunc = [&](OpOperand &opOperand) {
    return opOperand.getOperandNumber() < hoistPaddings.size()
               ? hoistPaddings[opOperand.getOperandNumber()]
               : 0;
  };
  linalgPaddingOptions.setPaddingValueComputationFunction(getNeutralOfLinalgOp);
  linalgPaddingOptions.setPaddingNoFoldComputationFunction(packFunc);
  linalgPaddingOptions.setPaddingHoistComputationFunction(hoistingFunc);
  padPattern.add<LinalgPaddingPattern>(
      context, linalgPaddingOptions,
      LinalgTransformationFilter(Identifier::get("pad", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(padPattern));
}

static void applyInterchangePattern(FuncOp funcOp,
                                    ArrayRef<unsigned> interchangeVector) {
  MLIRContext *context = funcOp.getContext();
  RewritePatternSet interchangePattern(context);
  interchangePattern.add<GenericOpInterchangePattern>(
      context, interchangeVector,
      LinalgTransformationFilter(ArrayRef<Identifier>{},
                                 Identifier::get("interchange", context)));
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(interchangePattern));
}

static constexpr char kPeeledLoopsLabel[] = "__peeled_loops__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
/// Peel TiledLoopOps, i.e., split them into two loops: one loop where the
/// `idx`-th loop contains only "full" iterations and a second loop for the
/// remaining partial iteration (if any).
struct TiledLoopPeelingPattern : public OpRewritePattern<TiledLoopOp> {
  TiledLoopPeelingPattern(MLIRContext *ctx, int64_t idx, bool skipPartial)
      : OpRewritePattern<TiledLoopOp>(ctx), idx(idx), skipPartial(skipPartial) {
  }

  LogicalResult matchAndRewrite(TiledLoopOp loopOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> peeledLoops;
    if (loopOp->hasAttr(kPeeledLoopsLabel)) {
      auto attr = loopOp->getAttr(kPeeledLoopsLabel).cast<ArrayAttr>();
      peeledLoops =
          llvm::to_vector<4>(llvm::map_range(attr, [](Attribute attr) {
            return attr.cast<IntegerAttr>().getInt();
          }));
      // Check if the loop was already peeled.
      if (llvm::find(peeledLoops, idx) != peeledLoops.end())
        return failure();
    }
    if (skipPartial && loopOp->hasAttr(kPartialIterationLabel))
      // No peeling of loop nests with a partial iteration.
      return failure();

    if (static_cast<int64_t>(loopOp.iterator_types().size()) <= idx)
      return failure();

    // Peel loop and canonicalize.
    TiledLoopOp result;
    if (failed(linalg::peelAndCanonicalizeTiledLoop(rewriter, loopOp, idx,
                                                    result)))
      return failure();

    // Apply label, so that the same loop is not rewritten a second time.
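    // The source loop is updated in place with the grown peeled-loops list;
    // the loop covering the partial iteration carries the same list plus the
    // partial-iteration marker.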
    peeledLoops.push_back(idx);
    rewriter.updateRootInPlace(loopOp, [&]() {
      loopOp->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    });
    result->setAttr(kPeeledLoopsLabel, rewriter.getI64ArrayAttr(peeledLoops));
    result->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());

    return success();
  }

  /// Index of loop to peel.
  int64_t idx;

  /// If set to true, do not peel TiledLoopOps with a partial iteration.
  bool skipPartial;
};
} // namespace

static void applyTiledLoopPeelingPattern(FuncOp funcOp,
                                         ArrayRef<unsigned> loops,
                                         bool skipPartial) {
  MLIRContext *ctx = funcOp.getContext();
  RewritePatternSet patterns(ctx);
  for (unsigned idx : loops)
    patterns.add<TiledLoopPeelingPattern>(ctx, idx, skipPartial);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

  // Drop the markers.
  funcOp.walk([](TiledLoopOp op) {
    op->removeAttr(kPeeledLoopsLabel);
    op->removeAttr(kPartialIterationLabel);
  });
}

/// Apply transformations specified as patterns.
void TestLinalgTransforms::runOnFunction() {
  auto lambda = [&](void *) {
    getFunction().walk([](LinalgOp op) {
      op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
    });
  };
  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};

  if (testPromotionOptions) {
    RewritePatternSet patterns(&getContext());
    fillPromotionCallBackPatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testTileAndDistributionOptions) {
    RewritePatternSet patterns(&getContext());
    fillTileAndDistributePatterns(&getContext(), patterns);
    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
    return;
  }
  if (testPatterns)
    return applyPatterns(getFunction());
  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
    return applyMatmulToVectorPatterns(getFunction(),
                                       testMatmulToVectorPatterns1dTiling,
                                       testMatmulToVectorPatterns2dTiling);
  if (testVectorTransferForwardingPatterns)
    return applyVectorTransferForwardingPatterns(getFunction());
  if (testGenericToVectorPattern)
    return applyLinalgToVectorPatterns(getFunction());
  if (testTransformPadTensor)
    return applyPadTensorToGenericPatterns(getFunction());
  if (testGeneralizePadTensor)
    return applyGeneralizePadTensorPatterns(getFunction());
  if (testSwapSubTensorPadTensor)
    return applyExtractSliceOfPadTensorSwapPattern(getFunction());
  if (testTiledLoopPeeling.hasValue())
    return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling,
                                        skipPartial);
  if (testTilePattern)
    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                            nofoldOperands, peeledLoops,
                            /*scalarizeDynamicDims=*/false);
  if (testTileScalarizeDynamicDims)
    return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands,
                            nofoldOperands,
                            /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true);
  if (testHoistPadding) {
    getFunction().walk([&](linalg::PadTensorOp padTensorOp) {
      PadTensorOp hoistedOp;
      FailureOr<Value> newResult = linalg::hoistPaddingOnTensors(
          padTensorOp, testHoistPadding, hoistedOp);
      if (succeeded(newResult)) {
        padTensorOp.getResult().replaceAllUsesWith(newResult.getValue());
        padTensorOp->erase();
      }
    });
  }
  if (testPadPattern)
    return applyPadPattern(getFunction(), packPaddings, hoistPaddings);
  if (testInterchangePattern.hasValue())
    return applyInterchangePattern(getFunction(), testInterchangePattern);
  if (testDecomposeConvolutionPattern)
    return applyDecomposeConvolutionPatterns(getFunction());
}

namespace mlir {
namespace test {
void registerTestLinalgTransforms() {
  PassRegistration<TestLinalgTransforms>();
}
} // namespace test
} // namespace mlir