1 //===- TestMemRefStrideCalculation.cpp - Pass to test strides computation--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/MemRef/IR/MemRef.h" 10 #include "mlir/IR/BuiltinTypes.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 struct TestMemRefStrideCalculation 17 : public PassWrapper<TestMemRefStrideCalculation, FunctionPass> { 18 void runOnFunction() override; 19 }; 20 } // end anonymous namespace 21 22 /// Traverse AllocOp and compute strides of each MemRefType independently. 23 void TestMemRefStrideCalculation::runOnFunction() { 24 llvm::outs() << "Testing: " << getFunction().getName() << "\n"; 25 getFunction().walk([&](memref::AllocOp allocOp) { 26 auto memrefType = allocOp.getResult().getType().cast<MemRefType>(); 27 int64_t offset; 28 SmallVector<int64_t, 4> strides; 29 if (failed(getStridesAndOffset(memrefType, strides, offset))) { 30 llvm::outs() << "MemRefType " << memrefType << " cannot be converted to " 31 << "strided form\n"; 32 return; 33 } 34 llvm::outs() << "MemRefType offset: "; 35 if (offset == MemRefType::getDynamicStrideOrOffset()) 36 llvm::outs() << "?"; 37 else 38 llvm::outs() << offset; 39 llvm::outs() << " strides: "; 40 llvm::interleaveComma(strides, llvm::outs(), [&](int64_t v) { 41 if (v == MemRefType::getDynamicStrideOrOffset()) 42 llvm::outs() << "?"; 43 else 44 llvm::outs() << v; 45 }); 46 llvm::outs() << "\n"; 47 }); 48 llvm::outs().flush(); 49 } 50 51 namespace mlir { 52 namespace test { 53 void registerTestMemRefStrideCalculation() { 54 PassRegistration<TestMemRefStrideCalculation> pass( 55 "test-memref-stride-calculation", "Test operation constant folding"); 56 } 57 } // namespace test 58 } // namespace mlir 59