//===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8*a70aa7bbSRiver Riddle 
9*a70aa7bbSRiver Riddle #include "PassDetail.h"
10*a70aa7bbSRiver Riddle #include "mlir/Dialect/SCF/Passes.h"
11*a70aa7bbSRiver Riddle #include "mlir/Dialect/SCF/SCF.h"
12*a70aa7bbSRiver Riddle #include "mlir/Dialect/SCF/Utils.h"
13*a70aa7bbSRiver Riddle #include "mlir/Transforms/RegionUtils.h"
14*a70aa7bbSRiver Riddle #include "llvm/Support/CommandLine.h"
15*a70aa7bbSRiver Riddle #include "llvm/Support/Debug.h"
16*a70aa7bbSRiver Riddle 
17*a70aa7bbSRiver Riddle #define DEBUG_TYPE "parallel-loop-collapsing"
18*a70aa7bbSRiver Riddle 
19*a70aa7bbSRiver Riddle using namespace mlir;
20*a70aa7bbSRiver Riddle 
21*a70aa7bbSRiver Riddle namespace {
22*a70aa7bbSRiver Riddle struct ParallelLoopCollapsing
23*a70aa7bbSRiver Riddle     : public SCFParallelLoopCollapsingBase<ParallelLoopCollapsing> {
24*a70aa7bbSRiver Riddle   void runOnOperation() override {
25*a70aa7bbSRiver Riddle     Operation *module = getOperation();
26*a70aa7bbSRiver Riddle 
27*a70aa7bbSRiver Riddle     module->walk([&](scf::ParallelOp op) {
28*a70aa7bbSRiver Riddle       // The common case for GPU dialect will be simplifying the ParallelOp to 3
29*a70aa7bbSRiver Riddle       // arguments, so we do that here to simplify things.
30*a70aa7bbSRiver Riddle       llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
31*a70aa7bbSRiver Riddle       if (!clCollapsedIndices0.empty())
32*a70aa7bbSRiver Riddle         combinedLoops.push_back(clCollapsedIndices0);
33*a70aa7bbSRiver Riddle       if (!clCollapsedIndices1.empty())
34*a70aa7bbSRiver Riddle         combinedLoops.push_back(clCollapsedIndices1);
35*a70aa7bbSRiver Riddle       if (!clCollapsedIndices2.empty())
36*a70aa7bbSRiver Riddle         combinedLoops.push_back(clCollapsedIndices2);
37*a70aa7bbSRiver Riddle       collapseParallelLoops(op, combinedLoops);
38*a70aa7bbSRiver Riddle     });
39*a70aa7bbSRiver Riddle   }
40*a70aa7bbSRiver Riddle };
41*a70aa7bbSRiver Riddle } // namespace
42*a70aa7bbSRiver Riddle 
43*a70aa7bbSRiver Riddle std::unique_ptr<Pass> mlir::createParallelLoopCollapsingPass() {
44*a70aa7bbSRiver Riddle   return std::make_unique<ParallelLoopCollapsing>();
45*a70aa7bbSRiver Riddle }
46