13f2891dbSJacques Pienaar# RUN: %PYTHON %s | FileCheck %s
23f2891dbSJacques Pienaar
33f2891dbSJacques Pienaarfrom mlir.ir import *
43f2891dbSJacques Pienaarimport numpy as np
536550692SRiver Riddleimport mlir.dialects.func as func
63f2891dbSJacques Pienaarimport mlir.dialects.shape as shape
73f2891dbSJacques Pienaar
83f2891dbSJacques Pienaar
93f2891dbSJacques Pienaardef run(f):
103f2891dbSJacques Pienaar  print("\nTEST:", f.__name__)
113f2891dbSJacques Pienaar  f()
123f2891dbSJacques Pienaar  return f
133f2891dbSJacques Pienaar
143f2891dbSJacques Pienaar
153f2891dbSJacques Pienaar# CHECK-LABEL: TEST: testConstShape
163f2891dbSJacques Pienaar@run
173f2891dbSJacques Pienaardef testConstShape():
183f2891dbSJacques Pienaar  with Context() as ctx, Location.unknown():
193f2891dbSJacques Pienaar    module = Module.create()
203f2891dbSJacques Pienaar    f32 = F32Type.get()
213f2891dbSJacques Pienaar    with InsertionPoint(module.body):
2236550692SRiver Riddle      @func.FuncOp.from_py_func(
233f2891dbSJacques Pienaar          RankedTensorType.get((12, -1), f32))
243f2891dbSJacques Pienaar      def const_shape_tensor(arg):
25ace1d0adSStella Laurenzo        return shape.ConstShapeOp(
26*057863a9SStella Stamenova          DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get()))
273f2891dbSJacques Pienaar
283f2891dbSJacques Pienaar    # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
293f2891dbSJacques Pienaar    # CHECK: shape.const_shape [10, 20] : tensor<2xindex>
303f2891dbSJacques Pienaar    print(module)
31