1 //===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/IR/BuiltinTypes.h"
11 #include "mlir/Pass/Pass.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 struct TestMemRefStrideCalculation
17     : public PassWrapper<TestMemRefStrideCalculation,
18                          InterfacePass<SymbolOpInterface>> {
19   StringRef getArgument() const final {
20     return "test-memref-stride-calculation";
21   }
22   StringRef getDescription() const final {
23     return "Test operation constant folding";
24   }
25   void runOnOperation() override;
26 };
27 } // namespace
28 
29 /// Traverse AllocOp and compute strides of each MemRefType independently.
30 void TestMemRefStrideCalculation::runOnOperation() {
31   llvm::outs() << "Testing: " << getOperation().getName() << "\n";
32   getOperation().walk([&](memref::AllocOp allocOp) {
33     auto memrefType = allocOp.getResult().getType().cast<MemRefType>();
34     int64_t offset;
35     SmallVector<int64_t, 4> strides;
36     if (failed(getStridesAndOffset(memrefType, strides, offset))) {
37       llvm::outs() << "MemRefType " << memrefType << " cannot be converted to "
38                    << "strided form\n";
39       return;
40     }
41     llvm::outs() << "MemRefType offset: ";
42     if (offset == MemRefType::getDynamicStrideOrOffset())
43       llvm::outs() << "?";
44     else
45       llvm::outs() << offset;
46     llvm::outs() << " strides: ";
47     llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) {
48       if (v == MemRefType::getDynamicStrideOrOffset())
49         llvm::outs() << "?";
50       else
51         llvm::outs() << v;
52     });
53     llvm::outs() << "\n";
54   });
55   llvm::outs().flush();
56 }
57 
58 namespace mlir {
59 namespace test {
60 void registerTestMemRefStrideCalculation() {
61   PassRegistration<TestMemRefStrideCalculation>();
62 }
63 } // namespace test
64 } // namespace mlir
65