1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
#  This file contains functions to convert between MemRefs and NumPy arrays.
6
7import numpy as np
8import ctypes
9
10
def make_nd_memref_descriptor(rank, dtype):
    """
    Return a new ctypes Structure type describing a ranked memref.

    Every call defines a fresh ``MemRefDescriptor`` class with the fields, in
    order: ``allocated`` (stored as a ``c_longlong``), ``aligned`` (typed as
    ``POINTER(dtype)``, so ``dtype`` must be a ctypes type), ``offset``, and
    two ``rank``-long ``c_longlong`` arrays, ``shape`` and ``strides``.
    Intended for rank > 0; rank-0 memrefs use the layout built by
    ``make_zero_d_memref_descriptor`` instead (no shape/strides).
    """
    # Shape and strides share one fixed-length integer array type.
    index_array_t = ctypes.c_longlong * rank

    class MemRefDescriptor(ctypes.Structure):
        """ctypes layout of a ranked memref descriptor (rank > 0)."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", ctypes.POINTER(dtype)),
            ("offset", ctypes.c_longlong),
            ("shape", index_array_t),
            ("strides", index_array_t),
        ]

    return MemRefDescriptor
26
27
def make_zero_d_memref_descriptor(dtype):
    """
    Return a new ctypes Structure type describing a rank-0 memref.

    A rank-0 memref has no shape or strides, so every call defines a fresh
    ``MemRefDescriptor`` class with only ``allocated`` (a ``c_longlong``),
    ``aligned`` (typed as ``POINTER(dtype)``, so ``dtype`` must be a ctypes
    type) and ``offset``.
    """
    element_ptr_t = ctypes.POINTER(dtype)

    class MemRefDescriptor(ctypes.Structure):
        """ctypes layout of a rank-0 memref descriptor."""

        _fields_ = [
            ("allocated", ctypes.c_longlong),
            ("aligned", element_ptr_t),
            ("offset", ctypes.c_longlong),
        ]

    return MemRefDescriptor
41
42
class UnrankedMemRefDescriptor(ctypes.Structure):
    """
    ctypes layout of an unranked memref: the rank, plus an untyped
    (``c_void_p``) pointer to the ranked descriptor that holds the data.
    """

    _fields_ = [
        ("rank", ctypes.c_longlong),
        ("descriptor", ctypes.c_void_p),
    ]
47
48
def get_ranked_memref_descriptor(nparray):
    """
    Return a ranked memref descriptor for the given numpy array.

    The descriptor points directly into ``nparray``'s buffer; no data is
    copied. Both ``allocated`` and ``aligned`` are set to the array's data
    pointer and ``offset`` is 0. For rank > 0, ``shape`` mirrors the array's
    shape and ``strides`` are the array's strides converted from bytes to
    elements.

    The returned descriptor keeps a reference to ``nparray``, so the buffer it
    points into is not freed while the descriptor is alive.

    Raises:
        ValueError: if a stride of ``nparray`` is not a whole multiple of its
            item size, which cannot be expressed in element units.
    """
    # Compute the element ctype once, so the `aligned` field type and the
    # pointer assigned to it are guaranteed to be the same ctypes type.
    ctype = np.ctypeslib.as_ctypes_type(nparray.dtype)
    ndim = nparray.ndim

    if ndim == 0:
        desc = make_zero_d_memref_descriptor(ctype)()
    else:
        desc = make_nd_memref_descriptor(ndim, ctype)()

    desc.allocated = nparray.ctypes.data
    desc.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctype))
    desc.offset = 0

    if ndim > 0:
        itemsize = nparray.itemsize
        # Numpy uses byte quantities to express strides, MLIR memrefs use
        # element counts. Floor division would silently truncate a stride
        # that is not a multiple of the item size, so reject it instead.
        if any(stride % itemsize for stride in nparray.strides):
            raise ValueError(
                f"strides {nparray.strides} are not multiples of the item "
                f"size {itemsize}; cannot build a memref descriptor"
            )
        index_array_t = ctypes.c_longlong * ndim
        desc.shape = index_array_t(*nparray.shape)
        desc.strides = index_array_t(
            *(stride // itemsize for stride in nparray.strides)
        )

    # The descriptor only stores raw addresses into nparray's buffer; hold a
    # reference so a temporary array is not freed while the descriptor
    # (or anything that keeps it alive) is still in use.
    desc._keep_alive = nparray
    return desc
77
78
def get_unranked_memref_descriptor(nparray):
    """
    Return a generic/unranked memref descriptor for the given numpy array.

    The ranked descriptor for ``nparray`` is built first. The unranked wrapper
    then records the array's rank and an untyped (``c_void_p``) pointer to
    that ranked descriptor.
    """
    ranked = get_ranked_memref_descriptor(nparray)

    unranked = UnrankedMemRefDescriptor()
    unranked.rank = nparray.ndim
    # NOTE(review): after this function returns, `ranked` is presumably kept
    # alive only through the ctypes keep-alive bookkeeping of the cast
    # result; confirm before changing how this pointer is produced.
    unranked.descriptor = ctypes.cast(ctypes.pointer(ranked), ctypes.c_void_p)
    return unranked
88
89
def unranked_memref_to_numpy(unranked_memref, np_dtype):
    """
    Converts unranked memrefs to numpy arrays.

    ``unranked_memref[0]`` must expose ``rank`` and ``descriptor`` (an
    untyped pointer to a ranked descriptor); presumably callers pass a ctypes
    pointer to an ``UnrankedMemRefDescriptor``. ``np_dtype`` is the element
    dtype used to reinterpret the ranked descriptor.

    The result is a view onto the memref's memory (no copy). Its first element
    is ``offset`` elements past the aligned pointer, and it uses the memref's
    shape and element strides.
    """
    ctype = np.ctypeslib.as_ctypes_type(np_dtype)
    descriptor_t = make_nd_memref_descriptor(unranked_memref[0].rank, ctype)
    desc = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor_t))[0]

    # Honour the descriptor's offset: the data starts `offset` elements past
    # the aligned pointer. Going through c_void_p avoids dereferencing the
    # pointer (a NULL `aligned` yields address 0, as before).
    element_ptr_t = ctypes.POINTER(ctype)
    aligned_addr = ctypes.cast(desc.aligned, ctypes.c_void_p).value or 0
    data_ptr = ctypes.cast(
        aligned_addr + desc.offset * ctypes.sizeof(ctype), element_ptr_t
    )

    np_arr = np.ctypeslib.as_array(data_ptr, shape=desc.shape)
    # Memref strides are in elements; NumPy wants bytes.
    return np.lib.stride_tricks.as_strided(
        np_arr,
        np.ctypeslib.as_array(desc.shape),
        np.ctypeslib.as_array(desc.strides) * np_arr.itemsize,
    )
105
106
def ranked_memref_to_numpy(ranked_memref):
    """
    Converts ranked memrefs to numpy arrays.

    ``ranked_memref[0]`` must be a ranked descriptor with ``aligned``,
    ``offset``, ``shape`` and ``strides`` fields; presumably callers pass a
    ctypes pointer to one.

    The result is a view onto the memref's memory (no copy). Its first element
    is ``offset`` elements past the aligned pointer, and it uses the memref's
    shape and element strides.
    """
    desc = ranked_memref[0]

    # Honour the descriptor's offset: the data starts `offset` elements past
    # the aligned pointer. Going through c_void_p avoids dereferencing the
    # pointer (a NULL `aligned` yields address 0, as before).
    element_ptr_t = type(desc.aligned)
    element_size = ctypes.sizeof(element_ptr_t._type_)
    aligned_addr = ctypes.cast(desc.aligned, ctypes.c_void_p).value or 0
    data_ptr = ctypes.cast(aligned_addr + desc.offset * element_size, element_ptr_t)

    np_arr = np.ctypeslib.as_array(data_ptr, shape=desc.shape)
    # Memref strides are in elements; NumPy wants bytes.
    return np.lib.stride_tricks.as_strided(
        np_arr,
        np.ctypeslib.as_array(desc.shape),
        np.ctypeslib.as_array(desc.strides) * np_arr.itemsize,
    )
120