# RUN: %PYTHON %s | FileCheck %s
"""Python-binding tests for the MLIR ``tensor`` dialect.

Each test prints a ``TEST: <name>`` banner followed by the IR it builds;
the FileCheck directives embedded in the comments verify that output.
"""

from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
import mlir.dialects.tensor as tensor


def run(f):
    """Print a banner naming ``f``, invoke it immediately, and return it.

    Applied as a decorator, so every test executes at import time and its
    output is preceded by a ``TEST: <name>`` line for the label directives
    to anchor on.
    """
    print(f"\nTEST: {f.__name__}")
    f()
    return f


# CHECK-LABEL: TEST: testDimOp
@run
def testDimOp():
    """Build and print a function returning both extents of a 2-D tensor.

    The generated function takes a ``tensor<?x?xf32>`` argument, creates
    index constants 0 and 1, applies ``tensor.dim`` with each, and returns
    the two results.
    """
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        index_type = IndexType.get()
        # Extents of -1 print as '?' (tensor<?x?xf32> in the directives
        # below), i.e. both dimensions are dynamic.
        # NOTE(review): newer MLIR bindings may expect
        # ShapedType.get_dynamic_size() rather than -1 — confirm against the
        # MLIR version this test is pinned to.
        dynamic_2d = RankedTensorType.get((-1, -1), f32)

        with InsertionPoint(module.body):

            @func.FuncOp.from_py_func(dynamic_2d)
            # CHECK: func @tensor_static_dim
            # CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
            # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
            # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
            # CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
            # CHECK: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
            # CHECK: return %[[D0]], %[[D1]]
            def tensor_static_dim(t):
                # Materialize both constants before any dim op: the
                # unordered constant matches must all be satisfied before
                # the first dim line is matched.
                zero, one = (arith.ConstantOp(index_type, i) for i in (0, 1))
                dim_ops = [tensor.DimOp(t, idx) for idx in (zero, one)]
                return [op.result for op in dim_ops]

        print(module)