1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
#  This file contains functions to convert MLIR memrefs to NumPy arrays and vice versa.
6
7import numpy as np
8import ctypes
9
10
class C128(ctypes.Structure):
  """ctypes layout for an MLIR double-precision complex element.

  Two consecutive C doubles: the real part followed by the imaginary part.
  """
  _fields_ = [
      ("real", ctypes.c_double),
      ("imag", ctypes.c_double),
  ]
14
15
class C64(ctypes.Structure):
  """ctypes layout for an MLIR single-precision complex element.

  Two consecutive C floats: the real part followed by the imaginary part.
  """
  _fields_ = [
      ("real", ctypes.c_float),
      ("imag", ctypes.c_float),
  ]
19
20
class F16(ctypes.Structure):
  """ctypes layout for an MLIR half-precision float element.

  ctypes has no 16-bit float type, so the raw 16 bits are held in a single
  c_int16 field named `f16`.
  """
  _fields_ = [
      ("f16", ctypes.c_int16),
  ]
24
25
def as_ctype(dtp):
  """Converts a NumPy dtype to the ctypes type used for one memref element.

  NumPy's complex and half-precision types have no ctypes counterpart, so
  they map to this module's C128 / C64 / F16 structures. Every other dtype is
  delegated to `np.ctypeslib.as_ctypes_type`.

  Args:
    dtp: a NumPy dtype, or anything `np.dtype` accepts (a scalar type such as
      `np.float32`, or a type string such as "float64").

  Returns:
    The ctypes type describing a single element of `dtp`.

  Raises:
    TypeError, NotImplementedError: propagated from `np.dtype` /
      `np.ctypeslib.as_ctypes_type` for inputs they cannot handle.
  """
  dtp = np.dtype(dtp)
  # Compare by value rather than identity. Equal dtypes need not be the same
  # object (e.g. built with `copy=True` or restored by unpickling), and plain
  # scalar types such as `np.complex128` must match as well.
  if dtp == np.dtype(np.complex128):
    return C128
  if dtp == np.dtype(np.complex64):
    return C64
  if dtp == np.dtype(np.float16):
    return F16
  return np.ctypeslib.as_ctypes_type(dtp)
35
36
def to_numpy(array):
  """Reinterprets an array of this module's element structs as NumPy types.

  If `array.dtype` compares equal to C128, C64 or F16, a view with the
  matching NumPy scalar type (complex128, complex64, float16) is returned.
  Any other array is returned unchanged.
  """
  struct_to_numpy_name = (
      (C128, "complex128"),
      (C64, "complex64"),
      (F16, "float16"),
  )
  for struct_t, numpy_name in struct_to_numpy_name:
    if array.dtype == struct_t:
      return array.view(numpy_name)
  return array
46
47
def make_nd_memref_descriptor(rank, dtype):
  """Returns a ctypes Structure type for a ranked memref descriptor.

  Args:
    rank: number of dimensions; sizes the `shape` and `strides` arrays.
    dtype: ctypes element type pointed to by the `aligned` field.

  Returns:
    A freshly created `MemRefDescriptor` Structure subclass with fields
    allocated, aligned, offset, shape[rank] and strides[rank].
  """
  index_t = ctypes.c_longlong
  index_array_t = index_t * rank
  element_ptr_t = ctypes.POINTER(dtype)

  class MemRefDescriptor(ctypes.Structure):
    """Empty ranked memref descriptor (rank > 0) for a fixed element type."""
    _fields_ = [
        ("allocated", index_t),
        ("aligned", element_ptr_t),
        ("offset", index_t),
        ("shape", index_array_t),
        ("strides", index_array_t),
    ]

  return MemRefDescriptor
62
63
def make_zero_d_memref_descriptor(dtype):
  """Returns a ctypes Structure type for a rank-0 memref descriptor.

  A rank-0 descriptor has no shape or strides, only the allocated address,
  the typed aligned pointer and the element offset.

  Args:
    dtype: ctypes element type pointed to by the `aligned` field.
  """
  index_t = ctypes.c_longlong
  element_ptr_t = ctypes.POINTER(dtype)

  class MemRefDescriptor(ctypes.Structure):
    """Empty rank-0 memref descriptor for a fixed element type."""
    _fields_ = [
        ("allocated", index_t),
        ("aligned", element_ptr_t),
        ("offset", index_t),
    ]

  return MemRefDescriptor
76
77
class UnrankedMemRefDescriptor(ctypes.Structure):
  """Type-erased memref: the rank plus an untyped pointer to a ranked
  descriptor."""
  _fields_ = [
      ("rank", ctypes.c_longlong),
      ("descriptor", ctypes.c_void_p),
  ]
81
82
def get_ranked_memref_descriptor(nparray):
  """Returns a ranked memref descriptor aliasing the data of `nparray`.

  No data is copied: the descriptor stores the array's raw data address, so
  the caller presumably has to keep `nparray` alive while the descriptor is
  in use.

  Args:
    nparray: a NumPy array. Its byte strides must be whole multiples of its
      itemsize, because MLIR expresses strides in elements.

  Returns:
    An instance of the struct from `make_zero_d_memref_descriptor` (0-d
    arrays) or `make_nd_memref_descriptor` (all other ranks), with offset 0.

  Raises:
    ValueError: if a stride is not a multiple of the element size.
  """
  ctp = as_ctype(nparray.dtype)
  rank = nparray.ndim
  if rank == 0:
    desc = make_zero_d_memref_descriptor(ctp)()
  else:
    desc = make_nd_memref_descriptor(rank, ctp)()
  desc.allocated = nparray.ctypes.data
  desc.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp))
  desc.offset = ctypes.c_longlong(0)
  if rank == 0:
    return desc

  # NumPy expresses strides in bytes, while MLIR (like torch) expresses them
  # in elements. A byte stride that is not a whole number of elements (e.g. a
  # field view of a packed structured array) cannot be represented, and
  # floor division would silently produce a wrong layout.
  itemsize = nparray.itemsize
  byte_strides = nparray.strides
  if any(stride % itemsize for stride in byte_strides):
    raise ValueError(
        f"array strides {byte_strides} are not multiples of its itemsize "
        f"{itemsize}; they cannot be expressed in elements for a memref")

  # Build the shape explicitly rather than assigning `nparray.ctypes.shape`:
  # that is a c_intp array, whose type only matches the c_longlong field on
  # platforms where the two coincide.
  index_array_t = ctypes.c_longlong * rank
  desc.shape = index_array_t(*nparray.shape)
  desc.strides = index_array_t(*(stride // itemsize for stride in byte_strides))
  return desc
104
105
def get_unranked_memref_descriptor(nparray):
  """Wraps `nparray` in an unranked (type-erased) memref descriptor.

  A ranked descriptor is built first. The unranked descriptor records the
  array's rank and refers to the ranked one through a `ctypes.c_void_p`.
  """
  ranked = get_ranked_memref_descriptor(nparray)
  unranked = UnrankedMemRefDescriptor()
  unranked.rank = nparray.ndim
  unranked.descriptor = ctypes.cast(ctypes.pointer(ranked), ctypes.c_void_p)
  return unranked
113
114
def unranked_memref_to_numpy(unranked_memref, np_dtype):
  """Converts an unranked memref to a strided NumPy view (no copy).

  Args:
    unranked_memref: ctypes pointer to an `UnrankedMemRefDescriptor`-shaped
      struct (read via `[0]` for its `rank` and `descriptor` fields).
    np_dtype: NumPy element dtype of the memref. The unranked descriptor does
      not record it, so the caller must supply it.

  Returns:
    A NumPy array viewing the memref's data. It starts `offset` elements
    past the aligned pointer and uses the descriptor's shape and element
    strides.
  """
  ctp = as_ctype(np_dtype)
  header = unranked_memref[0]
  descriptor_t = make_nd_memref_descriptor(header.rank, ctp)
  ranked = ctypes.cast(header.descriptor, ctypes.POINTER(descriptor_t))[0]

  data_ptr = ranked.aligned
  if ranked.offset:
    # MLIR locates element (i0, ..., ik) at aligned[offset + sum(i * stride)],
    # so a non-zero offset (e.g. from a subview) must shift the base pointer.
    base_addr = ctypes.cast(data_ptr, ctypes.c_void_p).value
    data_ptr = ctypes.cast(base_addr + ranked.offset * ctypes.sizeof(ctp),
                           ctypes.POINTER(ctp))

  np_arr = np.ctypeslib.as_array(data_ptr, shape=ranked.shape)
  strided_arr = np.lib.stride_tricks.as_strided(
      np_arr,
      np.ctypeslib.as_array(ranked.shape),
      # Descriptor strides are in elements; NumPy wants bytes.
      np.ctypeslib.as_array(ranked.strides) * np_arr.itemsize,
  )
  return to_numpy(strided_arr)
127
128
def ranked_memref_to_numpy(ranked_memref):
  """Converts a ranked memref to a strided NumPy view (no copy).

  Args:
    ranked_memref: ctypes pointer to a descriptor built by
      `make_nd_memref_descriptor` (read via `[0]`).

  Returns:
    A NumPy array viewing the memref's data. It starts `offset` elements
    past the aligned pointer and uses the descriptor's shape and element
    strides.
  """
  desc = ranked_memref[0]

  data_ptr = desc.aligned
  if desc.offset:
    # MLIR locates element (i0, ..., ik) at aligned[offset + sum(i * stride)],
    # so a non-zero offset (e.g. from a subview) must shift the base pointer.
    elem_t = data_ptr._type_
    base_addr = ctypes.cast(data_ptr, ctypes.c_void_p).value
    data_ptr = ctypes.cast(base_addr + desc.offset * ctypes.sizeof(elem_t),
                           ctypes.POINTER(elem_t))

  np_arr = np.ctypeslib.as_array(data_ptr, shape=desc.shape)
  strided_arr = np.lib.stride_tricks.as_strided(
      np_arr,
      np.ctypeslib.as_array(desc.shape),
      # Descriptor strides are in elements; NumPy wants bytes.
      np.ctypeslib.as_array(desc.strides) * np_arr.itemsize,
  )
  return to_numpy(strided_arr)
139