# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.builtin as builtin
import mlir.dialects.tensor as tensor


def run(f):
  print("\nTEST:", f.__name__)
  f()
  return f


# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
  with Context() as ctx, Location.unknown():
    module = Module.create()
    f32Type = F32Type.get()
    indexType = IndexType.get()
    with InsertionPoint(module.body):

      @builtin.FuncOp.from_py_func(RankedTensorType.get((-1, -1), f32Type))
      #      CHECK: func @tensor_static_dim
      # CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?xf32>
      #  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
      #  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
      #      CHECK:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
      #      CHECK:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
      #      CHECK:   return %[[D0]], %[[D1]]
      def tensor_static_dim(t):
        c0 = arith.ConstantOp(indexType, 0)
        c1 = arith.ConstantOp(indexType, 1)
        d0 = tensor.DimOp(t, c0)
        d1 = tensor.DimOp(t, c1)
        return [d0.result, d1.result]

    print(module)
