# RUN: %PYTHON %s | FileCheck %s

from mlir.ir import *
import numpy as np
import mlir.dialects.builtin as builtin
import mlir.dialects.shape as shape


def run(test_fn):
  """Print a `TEST:` banner naming `test_fn`, then invoke it.

  The banner gives each test's output a label FileCheck can anchor on.
  The function is returned unchanged, so this works as a decorator that
  runs the test at definition time.
  """
  print("\nTEST:", test_fn.__name__)
  test_fn()
  return test_fn


# CHECK-LABEL: TEST: testConstShape
@run
def testConstShape():
  """Build a func whose body returns a `shape.const_shape` op, then print it."""
  with Context(), Location.unknown():
    module = Module.create()
    element_type = F32Type.get()
    index_type = IndexType.get()
    # Type of the single function argument; prints as tensor<12x?xf32>.
    arg_type = RankedTensorType.get((12, -1), element_type)
    # Result type of the shape op; prints as tensor<2xindex>.
    shape_result_type = RankedTensorType.get((2,), index_type)
    extents = DenseElementsAttr.get(np.array([10, 20]))

    with InsertionPoint(module.body):
      # The Python function name is used as the emitted func's symbol
      # name, so it must stay in sync with the CHECK line below.
      @builtin.FuncOp.from_py_func(arg_type)
      def const_shape_tensor(arg):
        return shape.ConstShapeOp(shape_result_type, extents)

    # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
    # CHECK: shape.const_shape [10, 20] : tensor<2xindex>
    print(module)