1 //===- TestLoopMapping.cpp --- Parametric loop mapping pass ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to parametrically map scf.for loops to virtual
10 // processing element dimensions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/LoopUtils.h"
16 #include "mlir/Dialect/SCF/IR/SCF.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/Pass/Pass.h"
19 
20 #include "llvm/ADT/SetVector.h"
21 
22 using namespace mlir;
23 
24 namespace {
25 struct TestLoopMappingPass
26     : public PassWrapper<TestLoopMappingPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonc4182a8b0111::TestLoopMappingPass27   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopMappingPass)
28 
29   StringRef getArgument() const final {
30     return "test-mapping-to-processing-elements";
31   }
getDescription__anonc4182a8b0111::TestLoopMappingPass32   StringRef getDescription() const final {
33     return "test mapping a single loop on a virtual processor grid";
34   }
35   explicit TestLoopMappingPass() = default;
36 
getDependentDialects__anonc4182a8b0111::TestLoopMappingPass37   void getDependentDialects(DialectRegistry &registry) const override {
38     registry.insert<AffineDialect, scf::SCFDialect>();
39   }
40 
runOnOperation__anonc4182a8b0111::TestLoopMappingPass41   void runOnOperation() override {
42     // SSA values for the transformation are created out of thin air by
43     // unregistered "new_processor_id_and_range" operations. This is enough to
44     // emulate mapping conditions.
45     SmallVector<Value, 8> processorIds, numProcessors;
46     getOperation()->walk([&processorIds, &numProcessors](Operation *op) {
47       if (op->getName().getStringRef() != "new_processor_id_and_range")
48         return;
49       processorIds.push_back(op->getResult(0));
50       numProcessors.push_back(op->getResult(1));
51     });
52 
53     getOperation()->walk([&processorIds, &numProcessors](scf::ForOp op) {
54       // Ignore nested loops.
55       if (op->getParentRegion()->getParentOfType<scf::ForOp>())
56         return;
57       mapLoopToProcessorIds(op, processorIds, numProcessors);
58     });
59   }
60 };
61 } // namespace
62 
63 namespace mlir {
64 namespace test {
registerTestLoopMappingPass()65 void registerTestLoopMappingPass() { PassRegistration<TestLoopMappingPass>(); }
66 } // namespace test
67 } // namespace mlir
68