1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.std as std
5import mlir.dialects.memref as memref
6
7
def run(f):
  """Print a banner naming the test function `f`, then call it.

  The banner is an empty line followed by "TEST: <function name>", which gives
  the output of each test a recognizable header.

  Args:
    f: A zero-argument callable exposing a `__name__` attribute.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
11
12
13# CHECK-LABEL: TEST: testSubViewAccessors
def testSubViewAccessors():
  """Exercise the named accessors of a parsed `memref.subview` operation.

  A function holding six index constants and one subview (two offsets, two
  sizes, two strides) is parsed from assembly. The test asserts that the
  `source`, `offsets`, `sizes`, `strides` and `result` accessors line up with
  the generic operand/result lists, then prints the op class name and each
  accessor element so the FileCheck directives in the comments below can match
  them in order.
  """
  context = Context()
  asm = r"""
    func @f1(%arg0: memref<?x?xf32>) {
      %0 = arith.constant 0 : index
      %1 = arith.constant 1 : index
      %2 = arith.constant 2 : index
      %3 = arith.constant 3 : index
      %4 = arith.constant 4 : index
      %5 = arith.constant 5 : index
      memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
      return
    }
  """
  parsed_module = Module.parse(asm, context)
  func_op = parsed_module.body.operations[0]
  entry_block = func_op.regions[0].blocks[0]
  # Operations 0 through 5 are the index constants; the subview comes next.
  subview_op = entry_block.operations[6]

  # The source memref is expected to be the first generic operand.
  assert subview_op.source == subview_op.operands[0]
  offsets = subview_op.offsets
  sizes = subview_op.sizes
  strides = subview_op.strides
  # The assembly supplies exactly two values for each operand group.
  for operand_group in (offsets, sizes, strides):
    assert len(operand_group) == 2
  # The single result accessor must agree with the generic result list.
  assert subview_op.result == subview_op.results[0]

  # CHECK: SubViewOp
  print(type(subview_op).__name__)

  # CHECK: constant 0
  print(offsets[0])
  # CHECK: constant 1
  print(offsets[1])
  # CHECK: constant 2
  print(sizes[0])
  # CHECK: constant 3
  print(sizes[1])
  # CHECK: constant 4
  print(strides[0])
  # CHECK: constant 5
  print(strides[1])
53
54
# Execute the test when the file is run as a script; its stdout is what the
# FileCheck pipeline declared on the first line of this file verifies.
run(testSubViewAccessors)
56