13f2891dbSJacques Pienaar# RUN: %PYTHON %s | FileCheck %s 23f2891dbSJacques Pienaar 33f2891dbSJacques Pienaarfrom mlir.ir import * 43f2891dbSJacques Pienaarimport numpy as np 536550692SRiver Riddleimport mlir.dialects.func as func 63f2891dbSJacques Pienaarimport mlir.dialects.shape as shape 73f2891dbSJacques Pienaar 83f2891dbSJacques Pienaar 93f2891dbSJacques Pienaardef run(f): 103f2891dbSJacques Pienaar print("\nTEST:", f.__name__) 113f2891dbSJacques Pienaar f() 123f2891dbSJacques Pienaar return f 133f2891dbSJacques Pienaar 143f2891dbSJacques Pienaar 153f2891dbSJacques Pienaar# CHECK-LABEL: TEST: testConstShape 163f2891dbSJacques Pienaar@run 173f2891dbSJacques Pienaardef testConstShape(): 183f2891dbSJacques Pienaar with Context() as ctx, Location.unknown(): 193f2891dbSJacques Pienaar module = Module.create() 203f2891dbSJacques Pienaar f32 = F32Type.get() 213f2891dbSJacques Pienaar with InsertionPoint(module.body): 2236550692SRiver Riddle @func.FuncOp.from_py_func( 233f2891dbSJacques Pienaar RankedTensorType.get((12, -1), f32)) 243f2891dbSJacques Pienaar def const_shape_tensor(arg): 25ace1d0adSStella Laurenzo return shape.ConstShapeOp( 26*057863a9SStella Stamenova DenseElementsAttr.get(np.array([10, 20], dtype=np.int64), type=IndexType.get())) 273f2891dbSJacques Pienaar 283f2891dbSJacques Pienaar # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) 293f2891dbSJacques Pienaar # CHECK: shape.const_shape [10, 20] : tensor<2xindex> 303f2891dbSJacques Pienaar print(module) 31