//===- TestMultiBuffer.cpp - Test multi buffering transformation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
11 
12 #include "mlir/Dialect/MemRef/IR/MemRef.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 namespace {
18 struct TestMultiBufferingPass
19     : public PassWrapper<TestMultiBufferingPass, OperationPass<>> {
20   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMultiBufferingPass)
21 
22   TestMultiBufferingPass() = default;
23   TestMultiBufferingPass(const TestMultiBufferingPass &pass)
24       : PassWrapper(pass) {}
25   void getDependentDialects(DialectRegistry &registry) const override {
26     registry.insert<AffineDialect>();
27   }
28   StringRef getArgument() const final { return "test-multi-buffering"; }
29   StringRef getDescription() const final {
30     return "Test multi buffering transformation";
31   }
32   void runOnOperation() override;
33   Option<unsigned> multiplier{
34       *this, "multiplier",
35       llvm::cl::desc(
36           "Decide how many versions of the buffer should be created,"),
37       llvm::cl::init(2)};
38 };
39 
40 void TestMultiBufferingPass::runOnOperation() {
41   SmallVector<memref::AllocOp> allocs;
42   getOperation()->walk(
43       [&allocs](memref::AllocOp alloc) { allocs.push_back(alloc); });
44   for (memref::AllocOp alloc : allocs)
45     (void)multiBuffer(alloc, multiplier);
46 }
47 } // namespace
48 
49 namespace mlir {
50 namespace test {
51 void registerTestMultiBuffering() {
52   PassRegistration<TestMultiBufferingPass>();
53 }
54 } // namespace test
55 } // namespace mlir
56