# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
import mlir.dialects.tensor as tensor


def run(f):
  """Print a test banner for ``f``, execute it once, and return it.

  Meant to be used as a decorator: the decorated function runs immediately
  at definition time. The banner line (a newline, then ``TEST:`` and the
  function's ``__name__``) is what the FileCheck label directive placed
  above each test in this file matches against.

  Args:
    f: Zero-argument callable to execute.

  Returns:
    ``f`` itself, so the decorated name stays bound to the function.
  """
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
  """Emit a function that queries both dimensions of a 2-D f32 tensor.

  Builds a module containing one function whose single argument is a ranked
  tensor with two dynamic dimensions, creates index constants 0 and 1, applies
  ``tensor.dim`` to the argument with each constant, and returns both results.
  The module is then printed so FileCheck can verify the directives embedded
  in the comments below.
  """
  with Context(), Location.unknown():
    module = Module.create()
    f32Type = F32Type.get()
    indexType = IndexType.get()
    with InsertionPoint(module.body):

      # A size of -1 requests a dynamic dimension here; it is printed as `?`
      # in the expected tensor type below.
      @func.FuncOp.from_py_func(RankedTensorType.get((-1, -1), f32Type))
      #      CHECK: func @tensor_static_dim
      # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
      #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
      #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
      #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
      #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
      #      CHECK:   return %[[D0]], %[[D1]]
      def tensor_static_dim(t):
        c0 = arith.ConstantOp(indexType, 0)
        c1 = arith.ConstantOp(indexType, 1)
        d0 = tensor.DimOp(t, c0)
        d1 = tensor.DimOp(t, c1)
        # The returned values become the operands of the generated `return`.
        return [d0.result, d1.result]

    # Printed inside the context so the IR is still valid when rendered.
    print(module)
40