1 //===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <utility>
10
11 #include "IRModule.h"
12 #include "mlir-c/BuiltinAttributes.h"
13 #include "mlir-c/Interfaces.h"
14
15 namespace py = pybind11;
16
17 namespace mlir {
18 namespace python {
19
/// Python docstring for the interface constructor binding.
constexpr static const char *constructorDoc =
    R"(Creates an interface from a given operation/opview object or from a
subclass of OpView. Raises ValueError if the operation does not implement the
interface.)";

/// Python docstring for the read-only `operation` property.
constexpr static const char *operationDoc =
    R"(Returns an Operation for which the interface was constructed.)";

/// Python docstring for the read-only `opview` property.
constexpr static const char *opviewDoc =
    R"(Returns an OpView subclass _instance_ for which the interface was
constructed.)";

/// Python docstring for `InferTypeOpInterface.inferReturnTypes`.
constexpr static const char *inferReturnTypesDoc =
    R"(Given the arguments required to build an operation, attempts to infer
its return types. Raises ValueError on failure.)";
35
36 /// CRTP base class for Python classes representing MLIR Op interfaces.
37 /// Interface hierarchies are flat so no base class is expected here. The
38 /// derived class is expected to define the following static fields:
39 /// - `const char *pyClassName` - the name of the Python class to create;
40 /// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
41 /// of the interface.
42 /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
43 /// interface-specific methods.
44 ///
45 /// An interface class may be constructed from either an Operation/OpView object
46 /// or from a subclass of OpView. In the latter case, only the static interface
47 /// methods are available, similarly to calling ConcereteOp::staticMethod on the
48 /// C++ side. Implementations of concrete interfaces can use the `isStatic`
49 /// method to check whether the interface object was constructed from a class or
50 /// an operation/opview instance. The `getOpName` always succeeds and returns a
51 /// canonical name of the operation suitable for lookups.
52 template <typename ConcreteIface>
53 class PyConcreteOpInterface {
54 protected:
55 using ClassTy = py::class_<ConcreteIface>;
56 using GetTypeIDFunctionTy = MlirTypeID (*)();
57
58 public:
59 /// Constructs an interface instance from an object that is either an
60 /// operation or a subclass of OpView. In the latter case, only the static
61 /// methods of the interface are accessible to the caller.
PyConcreteOpInterface(py::object object,DefaultingPyMlirContext context)62 PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
63 : obj(std::move(object)) {
64 try {
65 operation = &py::cast<PyOperation &>(obj);
66 } catch (py::cast_error &) {
67 // Do nothing.
68 }
69
70 try {
71 operation = &py::cast<PyOpView &>(obj).getOperation();
72 } catch (py::cast_error &) {
73 // Do nothing.
74 }
75
76 if (operation != nullptr) {
77 if (!mlirOperationImplementsInterface(*operation,
78 ConcreteIface::getInterfaceID())) {
79 std::string msg = "the operation does not implement ";
80 throw py::value_error(msg + ConcreteIface::pyClassName);
81 }
82
83 MlirIdentifier identifier = mlirOperationGetName(*operation);
84 MlirStringRef stringRef = mlirIdentifierStr(identifier);
85 opName = std::string(stringRef.data, stringRef.length);
86 } else {
87 try {
88 opName = obj.attr("OPERATION_NAME").template cast<std::string>();
89 } catch (py::cast_error &) {
90 throw py::type_error(
91 "Op interface does not refer to an operation or OpView class");
92 }
93
94 if (!mlirOperationImplementsInterfaceStatic(
95 mlirStringRefCreate(opName.data(), opName.length()),
96 context.resolve().get(), ConcreteIface::getInterfaceID())) {
97 std::string msg = "the operation does not implement ";
98 throw py::value_error(msg + ConcreteIface::pyClassName);
99 }
100 }
101 }
102
103 /// Creates the Python bindings for this class in the given module.
bind(py::module & m)104 static void bind(py::module &m) {
105 py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
106 py::module_local());
107 cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
108 py::arg("context") = py::none(), constructorDoc)
109 .def_property_readonly("operation",
110 &PyConcreteOpInterface::getOperationObject,
111 operationDoc)
112 .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
113 opviewDoc);
114 ConcreteIface::bindDerived(cls);
115 }
116
117 /// Hook for derived classes to add class-specific bindings.
bindDerived(ClassTy & cls)118 static void bindDerived(ClassTy &cls) {}
119
120 /// Returns `true` if this object was constructed from a subclass of OpView
121 /// rather than from an operation instance.
isStatic()122 bool isStatic() { return operation == nullptr; }
123
124 /// Returns the operation instance from which this object was constructed.
125 /// Throws a type error if this object was constructed from a subclass of
126 /// OpView.
getOperationObject()127 py::object getOperationObject() {
128 if (operation == nullptr) {
129 throw py::type_error("Cannot get an operation from a static interface");
130 }
131
132 return operation->getRef().releaseObject();
133 }
134
135 /// Returns the opview of the operation instance from which this object was
136 /// constructed. Throws a type error if this object was constructed form a
137 /// subclass of OpView.
getOpView()138 py::object getOpView() {
139 if (operation == nullptr) {
140 throw py::type_error("Cannot get an opview from a static interface");
141 }
142
143 return operation->createOpView();
144 }
145
146 /// Returns the canonical name of the operation this interface is constructed
147 /// from.
getOpName()148 const std::string &getOpName() { return opName; }
149
150 private:
151 PyOperation *operation = nullptr;
152 std::string opName;
153 py::object obj;
154 };
155
156 /// Python wrapper for InterTypeOpInterface. This interface has only static
157 /// methods.
158 class PyInferTypeOpInterface
159 : public PyConcreteOpInterface<PyInferTypeOpInterface> {
160 public:
161 using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
162
163 constexpr static const char *pyClassName = "InferTypeOpInterface";
164 constexpr static GetTypeIDFunctionTy getInterfaceID =
165 &mlirInferTypeOpInterfaceTypeID;
166
167 /// C-style user-data structure for type appending callback.
168 struct AppendResultsCallbackData {
169 std::vector<PyType> &inferredTypes;
170 PyMlirContext &pyMlirContext;
171 };
172
173 /// Appends the types provided as the two first arguments to the user-data
174 /// structure (expects AppendResultsCallbackData).
appendResultsCallback(intptr_t nTypes,MlirType * types,void * userData)175 static void appendResultsCallback(intptr_t nTypes, MlirType *types,
176 void *userData) {
177 auto *data = static_cast<AppendResultsCallbackData *>(userData);
178 data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
179 for (intptr_t i = 0; i < nTypes; ++i) {
180 data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
181 }
182 }
183
184 /// Given the arguments required to build an operation, attempts to infer its
185 /// return types. Throws value_error on faliure.
186 std::vector<PyType>
inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,llvm::Optional<PyAttribute> attributes,llvm::Optional<std::vector<PyRegion>> regions,DefaultingPyMlirContext context,DefaultingPyLocation location)187 inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,
188 llvm::Optional<PyAttribute> attributes,
189 llvm::Optional<std::vector<PyRegion>> regions,
190 DefaultingPyMlirContext context,
191 DefaultingPyLocation location) {
192 llvm::SmallVector<MlirValue> mlirOperands;
193 llvm::SmallVector<MlirRegion> mlirRegions;
194
195 if (operands) {
196 mlirOperands.reserve(operands->size());
197 for (PyValue &value : *operands) {
198 mlirOperands.push_back(value);
199 }
200 }
201
202 if (regions) {
203 mlirRegions.reserve(regions->size());
204 for (PyRegion ®ion : *regions) {
205 mlirRegions.push_back(region);
206 }
207 }
208
209 std::vector<PyType> inferredTypes;
210 PyMlirContext &pyContext = context.resolve();
211 AppendResultsCallbackData data{inferredTypes, pyContext};
212 MlirStringRef opNameRef =
213 mlirStringRefCreate(getOpName().data(), getOpName().length());
214 MlirAttribute attributeDict =
215 attributes ? attributes->get() : mlirAttributeGetNull();
216
217 MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
218 opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
219 mlirOperands.data(), attributeDict, mlirRegions.size(),
220 mlirRegions.data(), &appendResultsCallback, &data);
221
222 if (mlirLogicalResultIsFailure(result)) {
223 throw py::value_error("Failed to infer result types");
224 }
225
226 return inferredTypes;
227 }
228
bindDerived(ClassTy & cls)229 static void bindDerived(ClassTy &cls) {
230 cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
231 py::arg("operands") = py::none(),
232 py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
233 py::arg("context") = py::none(), py::arg("loc") = py::none(),
234 inferReturnTypesDoc);
235 }
236 };
237
populateIRInterfaces(py::module & m)238 void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
239
240 } // namespace python
241 } // namespace mlir
242