1# RUN: %PYTHON %s | FileCheck %s
2
3from mlir.ir import *
4import mlir.dialects.arith as arith
5import mlir.dialects.func as func
6import mlir.dialects.tensor as tensor
7
8
def run(f):
  """Execute ``f`` immediately as a test case and hand it back unchanged.

  Before ``f`` runs, a blank line followed by ``TEST: <name>`` is written to
  stdout. The test labels in this file match on that banner, so each test's
  printed IR can be matched in isolation.

  Args:
    f: zero-argument callable holding the test body.

  Returns:
    ``f`` itself, so the decorated name stays bound to the original function.
  """
  banner = f"\nTEST: {f.__name__}"
  print(banner)
  f()
  return f
13
14
# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
  """Emit a function that queries both dimensions of a dynamic 2-D tensor.

  The printed module is matched by the FileCheck directives below: two index
  constants (0 and 1) feed two `tensor.dim` ops, and both results are
  returned from the function.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    index = IndexType.get()
    # (-1, -1) prints as tensor<?x?xf32> per the expected IR below.
    arg_type = RankedTensorType.get((-1, -1), f32)
    with InsertionPoint(module.body):

      @func.FuncOp.from_py_func(arg_type)
      #      CHECK: func @tensor_static_dim
      # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
      #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
      #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
      #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
      #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
      #      CHECK:   return %[[D0]], %[[D1]]
      def tensor_static_dim(t):
        # Create every index constant first, then one dim op per constant,
        # so ops are emitted in the order c0, c1, dim0, dim1.
        indices = [arith.ConstantOp(index, i) for i in range(2)]
        dims = [tensor.DimOp(t, idx) for idx in indices]
        return [d.result for d in dims]

    print(module)
40