# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains functions to convert between Memrefs and NumPy arrays and vice-versa.

import numpy as np
import ctypes


class C128(ctypes.Structure):
    """A ctype representation for MLIR's Double Complex."""

    _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)]


class C64(ctypes.Structure):
    """A ctype representation for MLIR's Float Complex."""

    _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)]


class F16(ctypes.Structure):
    """A ctype representation for MLIR's Float16."""

    # The half-precision value is carried as its raw 16-bit pattern.
    _fields_ = [("f16", ctypes.c_int16)]


def as_ctype(dtp):
    """Converts dtype to ctype.

    Args:
      dtp: A NumPy dtype, or anything `np.dtype()` accepts (a scalar type such
        as `np.complex64`, a string such as "float16", ...).

    Returns:
      `C128`, `C64` or `F16` for complex128, complex64 and float16
      respectively; otherwise whatever `np.ctypeslib.as_ctypes_type` returns.
    """
    # Normalise first and compare by equality: an identity check only matches
    # the exact dtype singleton and silently misses equal dtypes that were
    # spelled differently, sending them to `as_ctypes_type`, which has no
    # mapping for these three types.
    dtp = np.dtype(dtp)
    if dtp == np.dtype(np.complex128):
        return C128
    if dtp == np.dtype(np.complex64):
        return C64
    if dtp == np.dtype(np.float16):
        return F16
    return np.ctypeslib.as_ctypes_type(dtp)


def to_numpy(array):
    """Converts ctypes array back to numpy dtype array.

    Arrays whose dtype corresponds to one of the `C128`, `C64` or `F16`
    structures are reinterpreted (via `view`) as complex128, complex64 or
    float16; any other array is returned unchanged.
    """
    if array.dtype == C128:
        return array.view("complex128")
    if array.dtype == C64:
        return array.view("complex64")
    if array.dtype == F16:
        return array.view("float16")
    return array


def make_nd_memref_descriptor(rank, dtype):
    """Returns a new ctypes structure type describing a memref of `rank`.

    Args:
      rank: Number of dimensions; sizes the `shape` and `strides` arrays.
      dtype: ctypes element type that the `aligned` pointer refers to.
    """

    class MemRefDescriptor(ctypes.Structure):
        """Builds an empty descriptor for the given rank/dtype, where rank>0."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
            ("shape", ctypes.c_longlong * rank),
            ("strides", ctypes.c_longlong * rank),
        ]

    return MemRefDescriptor


def make_zero_d_memref_descriptor(dtype):
    """Returns a new ctypes structure type describing a rank-0 memref.

    Args:
      dtype: ctypes element type that the `aligned` pointer refers to.
    """

    class MemRefDescriptor(ctypes.Structure):
        """Builds an empty descriptor for the given dtype, where rank=0."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
        ]

    return MemRefDescriptor


class UnrankedMemRefDescriptor(ctypes.Structure):
    """Creates a ctype struct for memref descriptor"""

    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]


def get_ranked_memref_descriptor(nparray):
    """Returns a ranked memref descriptor for the given numpy array.

    The descriptor points at the array's own buffer (no copy is made), with
    `offset` set to 0.
    """
    ctp = as_ctype(nparray.dtype)
    data_ptr = nparray.ctypes.data_as(ctypes.POINTER(ctp))

    if nparray.ndim == 0:
        zero_d = make_zero_d_memref_descriptor(ctp)()
        zero_d.allocated = nparray.ctypes.data
        zero_d.aligned = data_ptr
        zero_d.offset = 0
        return zero_d

    descriptor = make_nd_memref_descriptor(nparray.ndim, ctp)()
    descriptor.allocated = nparray.ctypes.data
    descriptor.aligned = data_ptr
    descriptor.offset = 0

    # Build the index arrays with the exact ctypes type of the fields;
    # `nparray.ctypes.shape` is a c_intp array, which ctypes only accepts here
    # when c_intp and c_longlong happen to be the same type.
    index_array_t = ctypes.c_longlong * nparray.ndim
    descriptor.shape = index_array_t(*nparray.shape)

    # Numpy uses byte quantities to express strides, MLIR OTOH uses the
    # torch abstraction which specifies strides in terms of elements.
    descriptor.strides = index_array_t(
        *[byte_stride // nparray.itemsize for byte_stride in nparray.strides]
    )
    return descriptor


def get_unranked_memref_descriptor(nparray):
    """Returns a generic/unranked memref descriptor for the given numpy array.

    The result stores the array rank and the address of a ranked descriptor
    built by `get_ranked_memref_descriptor`.
    """
    d = UnrankedMemRefDescriptor()
    d.rank = nparray.ndim
    ranked = get_ranked_memref_descriptor(nparray)
    d.descriptor = ctypes.cast(ctypes.pointer(ranked), ctypes.c_void_p)
    # `descriptor` is only an address. Hold an explicit reference so the
    # ranked descriptor outlives this function regardless of how ctypes tracks
    # the object behind the cast pointer.
    d._ranked_keepalive = ranked
    return d


def _data_pointer(descriptor):
    """Returns a pointer to the first element addressed by `descriptor`.

    That is `aligned` advanced by `offset` elements. When the offset is zero
    the `aligned` pointer is returned as-is.
    """
    aligned = descriptor.aligned
    offset = descriptor.offset
    if not offset:
        return aligned
    element_size = ctypes.sizeof(aligned._type_)
    # `.value` is None for a null pointer; treat that as address 0.
    base_address = ctypes.cast(aligned, ctypes.c_void_p).value or 0
    return ctypes.cast(
        ctypes.c_void_p(base_address + offset * element_size), type(aligned)
    )


def _descriptor_to_numpy(descriptor):
    """Builds a NumPy view over the memory described by a ranked descriptor."""
    shape = tuple(descriptor.shape)
    np_arr = np.ctypeslib.as_array(_data_pointer(descriptor), shape=shape)
    # Descriptor strides count elements; NumPy strides count bytes.
    byte_strides = tuple(
        stride * np_arr.itemsize for stride in descriptor.strides
    )
    strided_arr = np.lib.stride_tricks.as_strided(np_arr, shape, byte_strides)
    return to_numpy(strided_arr)


def unranked_memref_to_numpy(unranked_memref, np_dtype):
    """Converts unranked memrefs to numpy arrays.

    Args:
      unranked_memref: Pointer-like object whose element 0 is an
        `UnrankedMemRefDescriptor`.
      np_dtype: NumPy dtype of the elements the memref holds.

    Returns:
      A NumPy array viewing the memref's memory, honouring its offset, shape
      and strides.
    """
    ctp = as_ctype(np_dtype)
    descriptor_t = make_nd_memref_descriptor(unranked_memref[0].rank, ctp)
    ranked = ctypes.cast(
        unranked_memref[0].descriptor, ctypes.POINTER(descriptor_t)
    )
    return _descriptor_to_numpy(ranked[0])


def ranked_memref_to_numpy(ranked_memref):
    """Converts ranked memrefs to numpy arrays.

    Args:
      ranked_memref: Pointer-like object whose element 0 is a ranked memref
        descriptor with `aligned`, `offset`, `shape` and `strides` fields.

    Returns:
      A NumPy array viewing the memref's memory, honouring its offset, shape
      and strides.
    """
    return _descriptor_to_numpy(ranked_memref[0])