1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.func as func
5import mlir.dialects.memref as memref
6
7
def run(f):
  """Decorator that executes a test function as soon as it is defined.

  Prints a ``TEST: <name>`` header first, so each test's output can be
  anchored by a FileCheck label. Then it calls ``f`` with no arguments
  and hands ``f`` back unchanged, so the decorated name stays bound to
  the original function.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  f()
  return f
12
13
14# CHECK-LABEL: TEST: testSubViewAccessors
@run
def testSubViewAccessors():
  """Exercise the generated operand/result accessors of ``memref.subview``.

  Parses a function whose single ``memref.subview`` takes its offsets,
  sizes and strides from six ``arith.constant`` ops. It then checks that
  the named accessors (``source``, ``offsets``, ``sizes``, ``strides``,
  ``result``) agree with the raw ``operands``/``results``. Finally it prints
  the op class name and every segment value for FileCheck.
  """
  context = Context()
  asm = r"""
    func.func @f1(%arg0: memref<?x?xf32>) {
      %0 = arith.constant 0 : index
      %1 = arith.constant 1 : index
      %2 = arith.constant 2 : index
      %3 = arith.constant 3 : index
      %4 = arith.constant 4 : index
      %5 = arith.constant 5 : index
      memref.subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
      return
    }
  """
  module = Module.parse(asm, context)
  func_op = module.body.operations[0]
  entry_block = func_op.regions[0].blocks[0]
  # The six constants occupy positions 0-5 of the block, so the subview is
  # the operation at index 6.
  subview = entry_block.operations[6]

  assert subview.source == subview.operands[0]
  for segment in (subview.offsets, subview.sizes, subview.strides):
    assert len(segment) == 2
  assert subview.result == subview.results[0]

  # CHECK: SubViewOp
  print(type(subview).__name__)

  # Segments are printed offsets, then sizes, then strides, two values
  # each, which yields the constants 0 through 5 in order.
  # CHECK: constant 0
  # CHECK: constant 1
  # CHECK: constant 2
  # CHECK: constant 3
  # CHECK: constant 4
  # CHECK: constant 5
  for segment in (subview.offsets, subview.sizes, subview.strides):
    for value in segment:
      print(value)
55
56
57# CHECK-LABEL: TEST: testCustomBuidlers
@run
def testCustomBuidlers():
  """Build a ``memref.load`` through the Python ``memref.LoadOp`` builder.

  The load is inserted just before the ``return`` of ``@f1``. It uses
  ``%arg0`` as the memref and the remaining block arguments as indices.
  The module is then printed for FileCheck and verified.

  The function name keeps its historical spelling because the FileCheck
  label preceding it must match the name printed by ``run``.
  """
  with Context() as ctx, Location.unknown(ctx):
    asm = r"""
      func.func @f1(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) {
        return
      }
    """
    module = Module.parse(asm)
    func_op = module.body.operations[0]
    entry_block = func_op.regions[0].blocks[0]
    source_memref = func_op.arguments[0]
    load_indices = func_op.arguments[1:]
    # Insert ahead of the terminator so the block still ends in `return`.
    with InsertionPoint.at_block_terminator(entry_block):
      memref.LoadOp(source_memref, load_indices)

    # CHECK: func @f1(%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
    # CHECK: memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]]
    print(module)
    assert module.operation.verify()
75