1 //===- TestComposeSubView.cpp - Test composed subviews --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
11 
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 struct TestMultiBufferingPass
19     : public PassWrapper<TestMultiBufferingPass, OperationPass<>> {
20   TestMultiBufferingPass() = default;
21   TestMultiBufferingPass(const TestMultiBufferingPass &pass)
22       : PassWrapper(pass) {}
23   void getDependentDialects(DialectRegistry &registry) const override {
24     registry.insert<AffineDialect>();
25   }
26   StringRef getArgument() const final { return "test-multi-buffering"; }
27   StringRef getDescription() const final {
28     return "Test multi buffering transformation";
29   }
30   void runOnOperation() override;
31   Option<unsigned> multiplier{
32       *this, "multiplier",
33       llvm::cl::desc(
34           "Decide how many versions of the buffer should be created,"),
35       llvm::cl::init(2)};
36 };
37 
38 void TestMultiBufferingPass::runOnOperation() {
39   SmallVector<memref::AllocOp> allocs;
40   getOperation()->walk(
41       [&allocs](memref::AllocOp alloc) { allocs.push_back(alloc); });
42   for (memref::AllocOp alloc : allocs)
43     (void)multiBuffer(alloc, multiplier);
44 }
45 } // namespace
46 
47 namespace mlir {
48 namespace test {
49 void registerTestMultiBuffering() {
50   PassRegistration<TestMultiBufferingPass>();
51 }
52 } // namespace test
53 } // namespace mlir
54