1a70aa7bbSRiver Riddle //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
2a70aa7bbSRiver Riddle //
3a70aa7bbSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a70aa7bbSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5a70aa7bbSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a70aa7bbSRiver Riddle //
7a70aa7bbSRiver Riddle //===----------------------------------------------------------------------===//
8a70aa7bbSRiver Riddle 
9a70aa7bbSRiver Riddle #include "PassDetail.h"
10*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
11*8b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/Transforms/Passes.h"
12f40475c7SAdrian Kuegel #include "mlir/Dialect/SCF/Utils/Utils.h"
13a70aa7bbSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
14a70aa7bbSRiver Riddle #include "llvm/Support/CommandLine.h"
15a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h"
16a70aa7bbSRiver Riddle 
17a70aa7bbSRiver Riddle #define DEBUG_TYPE "parallel-loop-collapsing"
18a70aa7bbSRiver Riddle 
19a70aa7bbSRiver Riddle using namespace mlir;
20a70aa7bbSRiver Riddle 
21a70aa7bbSRiver Riddle namespace {
22a70aa7bbSRiver Riddle struct ParallelLoopCollapsing
23a70aa7bbSRiver Riddle     : public SCFParallelLoopCollapsingBase<ParallelLoopCollapsing> {
runOnOperation__anon490b703a0111::ParallelLoopCollapsing24a70aa7bbSRiver Riddle   void runOnOperation() override {
25a70aa7bbSRiver Riddle     Operation *module = getOperation();
26a70aa7bbSRiver Riddle 
27a70aa7bbSRiver Riddle     module->walk([&](scf::ParallelOp op) {
28a70aa7bbSRiver Riddle       // The common case for GPU dialect will be simplifying the ParallelOp to 3
29a70aa7bbSRiver Riddle       // arguments, so we do that here to simplify things.
30a70aa7bbSRiver Riddle       llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
31a70aa7bbSRiver Riddle       if (!clCollapsedIndices0.empty())
32a70aa7bbSRiver Riddle         combinedLoops.push_back(clCollapsedIndices0);
33a70aa7bbSRiver Riddle       if (!clCollapsedIndices1.empty())
34a70aa7bbSRiver Riddle         combinedLoops.push_back(clCollapsedIndices1);
35a70aa7bbSRiver Riddle       if (!clCollapsedIndices2.empty())
36a70aa7bbSRiver Riddle         combinedLoops.push_back(clCollapsedIndices2);
37a70aa7bbSRiver Riddle       collapseParallelLoops(op, combinedLoops);
38a70aa7bbSRiver Riddle     });
39a70aa7bbSRiver Riddle   }
40a70aa7bbSRiver Riddle };
41a70aa7bbSRiver Riddle } // namespace
42a70aa7bbSRiver Riddle 
createParallelLoopCollapsingPass()43a70aa7bbSRiver Riddle std::unique_ptr<Pass> mlir::createParallelLoopCollapsingPass() {
44a70aa7bbSRiver Riddle   return std::make_unique<ParallelLoopCollapsing>();
45a70aa7bbSRiver Riddle }
46