# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains functions to convert between MemRef descriptors and NumPy
# arrays.

import ctypes

import numpy as np


def make_nd_memref_descriptor(rank, dtype):
    """
    Return a ctypes Structure type for a ranked memref descriptor.

    The struct has the fields `allocated`, `aligned` (a pointer to `dtype`),
    `offset`, and `shape`/`strides` arrays holding `rank` c_longlong values
    each. Intended for rank > 0; rank-0 memrefs use
    make_zero_d_memref_descriptor.
    """

    class MemRefDescriptor(ctypes.Structure):
        """Ranked memref descriptor whose shape/strides hold `rank` entries."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
            ("shape", ctypes.c_longlong * rank),
            ("strides", ctypes.c_longlong * rank),
        ]

    return MemRefDescriptor


def make_zero_d_memref_descriptor(dtype):
    """
    Return a ctypes Structure type for a rank-0 memref descriptor.

    It has the same leading fields as make_nd_memref_descriptor (`allocated`,
    `aligned` as a pointer to `dtype`, `offset`) but no `shape` or `strides`.
    """

    class MemRefDescriptor(ctypes.Structure):
        """Rank-0 memref descriptor (no shape/strides fields)."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
        ]

    return MemRefDescriptor


class UnrankedMemRefDescriptor(ctypes.Structure):
    """ctypes struct for an unranked memref: its rank plus an untyped pointer
    to the matching ranked descriptor."""

    _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)]


def get_ranked_memref_descriptor(nparray):
    """
    Return a ranked memref descriptor for the given numpy array.

    Rank-0 arrays get a descriptor built by make_zero_d_memref_descriptor;
    other ranks use make_nd_memref_descriptor, with `shape` copied from the
    array and `strides` converted from bytes to elements. `offset` is always 0
    because `aligned` already points at the array's data.

    The descriptor stores only raw addresses, so a reference to `nparray` is
    attached to it (as `_keepalive`) to keep the data alive for as long as the
    descriptor object is alive.

    Raises:
        ValueError: if a byte stride of `nparray` is not a multiple of its item
            size and therefore cannot be expressed in elements.
    """
    elem_ctype = np.ctypeslib.as_ctypes_type(nparray.dtype)
    rank = nparray.ndim
    itemsize = nparray.itemsize

    if rank == 0:
        x = make_zero_d_memref_descriptor(elem_ctype)()
    else:
        x = make_nd_memref_descriptor(rank, elem_ctype)()
        index_array_t = ctypes.c_longlong * rank
        # Build the c_longlong array explicitly: `nparray.ctypes.shape` uses
        # c_intp elements, which may be a different ctypes type than
        # c_longlong on some platforms and then cannot be assigned directly.
        x.shape = index_array_t(*nparray.shape)
        # NumPy expresses strides in bytes; memref strides count elements.
        if any(stride % itemsize for stride in nparray.strides):
            raise ValueError(
                f"cannot express byte strides {nparray.strides} in units of "
                f"the {itemsize}-byte element size"
            )
        x.strides = index_array_t(*(stride // itemsize for stride in nparray.strides))

    x.allocated = nparray.ctypes.data
    x.aligned = nparray.ctypes.data_as(ctypes.POINTER(elem_ctype))
    x.offset = ctypes.c_longlong(0)
    # The reference that `data_as` attaches to its returned pointer is not
    # carried over when the pointer value is copied into the struct field, so
    # keep the array alive explicitly.
    x._keepalive = nparray
    return x


def get_unranked_memref_descriptor(nparray):
    """
    Return a generic/unranked memref descriptor for the given numpy array.

    The `descriptor` field points at a ranked descriptor built by
    get_ranked_memref_descriptor. A Python reference to that ranked descriptor
    (and, through it, to `nparray`) is attached as `_keepalive`, because the
    `descriptor` field itself is only an untyped address.
    """
    d = UnrankedMemRefDescriptor()
    d.rank = nparray.ndim
    x = get_ranked_memref_descriptor(nparray)
    d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p)
    d._keepalive = x
    return d


def _descriptor_to_numpy(descriptor):
    """Return a NumPy array for the elements described by a ranked descriptor.

    Handles rank-0 descriptors (which lack `shape`/`strides`) and applies the
    descriptor's element `offset` to its `aligned` pointer. Except for
    zero-size shapes, the result is a view onto the memref's memory, not a
    copy.
    """
    aligned = descriptor.aligned
    elem_ctype = aligned._type_
    shape = tuple(getattr(descriptor, "shape", ()))
    strides = tuple(getattr(descriptor, "strides", ()))

    if 0 in shape:
        # No elements to view; avoid dereferencing a possibly-NULL pointer.
        return np.empty(shape, dtype=np.dtype(elem_ctype))

    if descriptor.offset:
        # The first element lives at aligned + offset; offset counts elements.
        address = ctypes.cast(aligned, ctypes.c_void_p).value
        address += descriptor.offset * ctypes.sizeof(elem_ctype)
        aligned = ctypes.cast(ctypes.c_void_p(address), ctypes.POINTER(elem_ctype))

    if not shape:
        # Rank 0: view the single element, then drop the length-1 axis.
        return np.ctypeslib.as_array(aligned, shape=(1,)).reshape(())

    base = np.ctypeslib.as_array(aligned, shape=shape)
    # Memref strides are in elements; as_strided expects bytes.
    return np.lib.stride_tricks.as_strided(
        base, shape, tuple(stride * base.itemsize for stride in strides)
    )


def unranked_memref_to_numpy(unranked_memref, np_dtype):
    """
    Converts unranked memrefs to numpy arrays.

    `unranked_memref[0]` must be an UnrankedMemRefDescriptor (for example,
    `unranked_memref` is a ctypes pointer to one); `np_dtype` gives the element
    type used to interpret the pointed-to ranked descriptor.
    """
    unranked = unranked_memref[0]
    elem_ctype = np.ctypeslib.as_ctypes_type(np_dtype)
    if unranked.rank == 0:
        descriptor_t = make_zero_d_memref_descriptor(elem_ctype)
    else:
        descriptor_t = make_nd_memref_descriptor(unranked.rank, elem_ctype)
    ranked = ctypes.cast(unranked.descriptor, ctypes.POINTER(descriptor_t))
    return _descriptor_to_numpy(ranked[0])


def ranked_memref_to_numpy(ranked_memref):
    """
    Converts ranked memrefs to numpy arrays.

    `ranked_memref[0]` must be a descriptor built from
    make_nd_memref_descriptor or make_zero_d_memref_descriptor (for example,
    `ranked_memref` is a ctypes pointer to one). The element type comes from
    the descriptor's `aligned` pointer type.
    """
    return _descriptor_to_numpy(ranked_memref[0])