//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
814c92070SAlex Zinenko 
91fc096afSMehdi Amini #include <utility>
101fc096afSMehdi Amini 
1114c92070SAlex Zinenko #include "IRModule.h"
1214c92070SAlex Zinenko #include "mlir-c/BuiltinAttributes.h"
1314c92070SAlex Zinenko #include "mlir-c/Interfaces.h"
1414c92070SAlex Zinenko 
1514c92070SAlex Zinenko namespace py = pybind11;
1614c92070SAlex Zinenko 
1714c92070SAlex Zinenko namespace mlir {
1814c92070SAlex Zinenko namespace python {
1914c92070SAlex Zinenko 
// Python docstrings attached to the op interface bindings.

/// Docstring for the interface constructor.
static constexpr const char *constructorDoc =
    "Creates an interface from a given operation/opview object or from a\n"
    "subclass of OpView. Raises ValueError if the operation does not "
    "implement the\ninterface.";

/// Docstring for the `operation` property.
static constexpr const char *operationDoc =
    "Returns an Operation for which the interface was constructed.";

/// Docstring for the `opview` property.
static constexpr const char *opviewDoc =
    "Returns an OpView subclass _instance_ for which the interface was\n"
    "constructed";

/// Docstring for `InferTypeOpInterface.inferReturnTypes`.
static constexpr const char *inferReturnTypesDoc =
    "Given the arguments required to build an operation, attempts to infer\n"
    "its return types. Raises ValueError on failure.";
3514c92070SAlex Zinenko 
3614c92070SAlex Zinenko /// CRTP base class for Python classes representing MLIR Op interfaces.
3714c92070SAlex Zinenko /// Interface hierarchies are flat so no base class is expected here. The
3814c92070SAlex Zinenko /// derived class is expected to define the following static fields:
3914c92070SAlex Zinenko ///  - `const char *pyClassName` - the name of the Python class to create;
4014c92070SAlex Zinenko ///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
4114c92070SAlex Zinenko ///    of the interface.
4214c92070SAlex Zinenko /// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
4314c92070SAlex Zinenko /// interface-specific methods.
4414c92070SAlex Zinenko ///
4514c92070SAlex Zinenko /// An interface class may be constructed from either an Operation/OpView object
4614c92070SAlex Zinenko /// or from a subclass of OpView. In the latter case, only the static interface
4714c92070SAlex Zinenko /// methods are available, similarly to calling ConcereteOp::staticMethod on the
4814c92070SAlex Zinenko /// C++ side. Implementations of concrete interfaces can use the `isStatic`
4914c92070SAlex Zinenko /// method to check whether the interface object was constructed from a class or
5014c92070SAlex Zinenko /// an operation/opview instance. The `getOpName` always succeeds and returns a
5114c92070SAlex Zinenko /// canonical name of the operation suitable for lookups.
5214c92070SAlex Zinenko template <typename ConcreteIface>
5314c92070SAlex Zinenko class PyConcreteOpInterface {
5414c92070SAlex Zinenko protected:
5514c92070SAlex Zinenko   using ClassTy = py::class_<ConcreteIface>;
5614c92070SAlex Zinenko   using GetTypeIDFunctionTy = MlirTypeID (*)();
5714c92070SAlex Zinenko 
5814c92070SAlex Zinenko public:
5914c92070SAlex Zinenko   /// Constructs an interface instance from an object that is either an
6014c92070SAlex Zinenko   /// operation or a subclass of OpView. In the latter case, only the static
6114c92070SAlex Zinenko   /// methods of the interface are accessible to the caller.
PyConcreteOpInterface(py::object object,DefaultingPyMlirContext context)6214c92070SAlex Zinenko   PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
631fc096afSMehdi Amini       : obj(std::move(object)) {
6414c92070SAlex Zinenko     try {
6514c92070SAlex Zinenko       operation = &py::cast<PyOperation &>(obj);
66*8fb1bef6SNathaniel McVicar     } catch (py::cast_error &) {
6714c92070SAlex Zinenko       // Do nothing.
6814c92070SAlex Zinenko     }
6914c92070SAlex Zinenko 
7014c92070SAlex Zinenko     try {
7114c92070SAlex Zinenko       operation = &py::cast<PyOpView &>(obj).getOperation();
72*8fb1bef6SNathaniel McVicar     } catch (py::cast_error &) {
7314c92070SAlex Zinenko       // Do nothing.
7414c92070SAlex Zinenko     }
7514c92070SAlex Zinenko 
7614c92070SAlex Zinenko     if (operation != nullptr) {
7714c92070SAlex Zinenko       if (!mlirOperationImplementsInterface(*operation,
7814c92070SAlex Zinenko                                             ConcreteIface::getInterfaceID())) {
7914c92070SAlex Zinenko         std::string msg = "the operation does not implement ";
8014c92070SAlex Zinenko         throw py::value_error(msg + ConcreteIface::pyClassName);
8114c92070SAlex Zinenko       }
8214c92070SAlex Zinenko 
8314c92070SAlex Zinenko       MlirIdentifier identifier = mlirOperationGetName(*operation);
8414c92070SAlex Zinenko       MlirStringRef stringRef = mlirIdentifierStr(identifier);
8514c92070SAlex Zinenko       opName = std::string(stringRef.data, stringRef.length);
8614c92070SAlex Zinenko     } else {
8714c92070SAlex Zinenko       try {
8814c92070SAlex Zinenko         opName = obj.attr("OPERATION_NAME").template cast<std::string>();
89*8fb1bef6SNathaniel McVicar       } catch (py::cast_error &) {
9014c92070SAlex Zinenko         throw py::type_error(
9114c92070SAlex Zinenko             "Op interface does not refer to an operation or OpView class");
9214c92070SAlex Zinenko       }
9314c92070SAlex Zinenko 
9414c92070SAlex Zinenko       if (!mlirOperationImplementsInterfaceStatic(
9514c92070SAlex Zinenko               mlirStringRefCreate(opName.data(), opName.length()),
9614c92070SAlex Zinenko               context.resolve().get(), ConcreteIface::getInterfaceID())) {
9714c92070SAlex Zinenko         std::string msg = "the operation does not implement ";
9814c92070SAlex Zinenko         throw py::value_error(msg + ConcreteIface::pyClassName);
9914c92070SAlex Zinenko       }
10014c92070SAlex Zinenko     }
10114c92070SAlex Zinenko   }
10214c92070SAlex Zinenko 
10314c92070SAlex Zinenko   /// Creates the Python bindings for this class in the given module.
bind(py::module & m)10414c92070SAlex Zinenko   static void bind(py::module &m) {
10514c92070SAlex Zinenko     py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
10614c92070SAlex Zinenko                                   py::module_local());
10714c92070SAlex Zinenko     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
10814c92070SAlex Zinenko             py::arg("context") = py::none(), constructorDoc)
10914c92070SAlex Zinenko         .def_property_readonly("operation",
11014c92070SAlex Zinenko                                &PyConcreteOpInterface::getOperationObject,
11114c92070SAlex Zinenko                                operationDoc)
11214c92070SAlex Zinenko         .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
11314c92070SAlex Zinenko                                opviewDoc);
11414c92070SAlex Zinenko     ConcreteIface::bindDerived(cls);
11514c92070SAlex Zinenko   }
11614c92070SAlex Zinenko 
11714c92070SAlex Zinenko   /// Hook for derived classes to add class-specific bindings.
bindDerived(ClassTy & cls)11814c92070SAlex Zinenko   static void bindDerived(ClassTy &cls) {}
11914c92070SAlex Zinenko 
12014c92070SAlex Zinenko   /// Returns `true` if this object was constructed from a subclass of OpView
12114c92070SAlex Zinenko   /// rather than from an operation instance.
isStatic()12214c92070SAlex Zinenko   bool isStatic() { return operation == nullptr; }
12314c92070SAlex Zinenko 
12414c92070SAlex Zinenko   /// Returns the operation instance from which this object was constructed.
12514c92070SAlex Zinenko   /// Throws a type error if this object was constructed from a subclass of
12614c92070SAlex Zinenko   /// OpView.
getOperationObject()12714c92070SAlex Zinenko   py::object getOperationObject() {
12814c92070SAlex Zinenko     if (operation == nullptr) {
12914c92070SAlex Zinenko       throw py::type_error("Cannot get an operation from a static interface");
13014c92070SAlex Zinenko     }
13114c92070SAlex Zinenko 
13214c92070SAlex Zinenko     return operation->getRef().releaseObject();
13314c92070SAlex Zinenko   }
13414c92070SAlex Zinenko 
13514c92070SAlex Zinenko   /// Returns the opview of the operation instance from which this object was
13614c92070SAlex Zinenko   /// constructed. Throws a type error if this object was constructed form a
13714c92070SAlex Zinenko   /// subclass of OpView.
getOpView()13814c92070SAlex Zinenko   py::object getOpView() {
13914c92070SAlex Zinenko     if (operation == nullptr) {
14014c92070SAlex Zinenko       throw py::type_error("Cannot get an opview from a static interface");
14114c92070SAlex Zinenko     }
14214c92070SAlex Zinenko 
14314c92070SAlex Zinenko     return operation->createOpView();
14414c92070SAlex Zinenko   }
14514c92070SAlex Zinenko 
14614c92070SAlex Zinenko   /// Returns the canonical name of the operation this interface is constructed
14714c92070SAlex Zinenko   /// from.
getOpName()14814c92070SAlex Zinenko   const std::string &getOpName() { return opName; }
14914c92070SAlex Zinenko 
15014c92070SAlex Zinenko private:
15114c92070SAlex Zinenko   PyOperation *operation = nullptr;
15214c92070SAlex Zinenko   std::string opName;
15314c92070SAlex Zinenko   py::object obj;
15414c92070SAlex Zinenko };
15514c92070SAlex Zinenko 
15614c92070SAlex Zinenko /// Python wrapper for InterTypeOpInterface. This interface has only static
15714c92070SAlex Zinenko /// methods.
15814c92070SAlex Zinenko class PyInferTypeOpInterface
15914c92070SAlex Zinenko     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
16014c92070SAlex Zinenko public:
16114c92070SAlex Zinenko   using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
16214c92070SAlex Zinenko 
16314c92070SAlex Zinenko   constexpr static const char *pyClassName = "InferTypeOpInterface";
16414c92070SAlex Zinenko   constexpr static GetTypeIDFunctionTy getInterfaceID =
16514c92070SAlex Zinenko       &mlirInferTypeOpInterfaceTypeID;
16614c92070SAlex Zinenko 
16714c92070SAlex Zinenko   /// C-style user-data structure for type appending callback.
16814c92070SAlex Zinenko   struct AppendResultsCallbackData {
16914c92070SAlex Zinenko     std::vector<PyType> &inferredTypes;
17014c92070SAlex Zinenko     PyMlirContext &pyMlirContext;
17114c92070SAlex Zinenko   };
17214c92070SAlex Zinenko 
17314c92070SAlex Zinenko   /// Appends the types provided as the two first arguments to the user-data
17414c92070SAlex Zinenko   /// structure (expects AppendResultsCallbackData).
appendResultsCallback(intptr_t nTypes,MlirType * types,void * userData)17514c92070SAlex Zinenko   static void appendResultsCallback(intptr_t nTypes, MlirType *types,
17614c92070SAlex Zinenko                                     void *userData) {
17714c92070SAlex Zinenko     auto *data = static_cast<AppendResultsCallbackData *>(userData);
17814c92070SAlex Zinenko     data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
17914c92070SAlex Zinenko     for (intptr_t i = 0; i < nTypes; ++i) {
180e5639b3fSMehdi Amini       data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]);
18114c92070SAlex Zinenko     }
18214c92070SAlex Zinenko   }
18314c92070SAlex Zinenko 
18414c92070SAlex Zinenko   /// Given the arguments required to build an operation, attempts to infer its
18514c92070SAlex Zinenko   /// return types. Throws value_error on faliure.
18614c92070SAlex Zinenko   std::vector<PyType>
inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,llvm::Optional<PyAttribute> attributes,llvm::Optional<std::vector<PyRegion>> regions,DefaultingPyMlirContext context,DefaultingPyLocation location)18714c92070SAlex Zinenko   inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,
18814c92070SAlex Zinenko                    llvm::Optional<PyAttribute> attributes,
18914c92070SAlex Zinenko                    llvm::Optional<std::vector<PyRegion>> regions,
19014c92070SAlex Zinenko                    DefaultingPyMlirContext context,
19114c92070SAlex Zinenko                    DefaultingPyLocation location) {
19214c92070SAlex Zinenko     llvm::SmallVector<MlirValue> mlirOperands;
19314c92070SAlex Zinenko     llvm::SmallVector<MlirRegion> mlirRegions;
19414c92070SAlex Zinenko 
19514c92070SAlex Zinenko     if (operands) {
19614c92070SAlex Zinenko       mlirOperands.reserve(operands->size());
19714c92070SAlex Zinenko       for (PyValue &value : *operands) {
19814c92070SAlex Zinenko         mlirOperands.push_back(value);
19914c92070SAlex Zinenko       }
20014c92070SAlex Zinenko     }
20114c92070SAlex Zinenko 
20214c92070SAlex Zinenko     if (regions) {
20314c92070SAlex Zinenko       mlirRegions.reserve(regions->size());
20414c92070SAlex Zinenko       for (PyRegion &region : *regions) {
20514c92070SAlex Zinenko         mlirRegions.push_back(region);
20614c92070SAlex Zinenko       }
20714c92070SAlex Zinenko     }
20814c92070SAlex Zinenko 
20914c92070SAlex Zinenko     std::vector<PyType> inferredTypes;
21014c92070SAlex Zinenko     PyMlirContext &pyContext = context.resolve();
21114c92070SAlex Zinenko     AppendResultsCallbackData data{inferredTypes, pyContext};
21214c92070SAlex Zinenko     MlirStringRef opNameRef =
21314c92070SAlex Zinenko         mlirStringRefCreate(getOpName().data(), getOpName().length());
21414c92070SAlex Zinenko     MlirAttribute attributeDict =
21514c92070SAlex Zinenko         attributes ? attributes->get() : mlirAttributeGetNull();
21614c92070SAlex Zinenko 
21714c92070SAlex Zinenko     MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
21814c92070SAlex Zinenko         opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
21914c92070SAlex Zinenko         mlirOperands.data(), attributeDict, mlirRegions.size(),
22014c92070SAlex Zinenko         mlirRegions.data(), &appendResultsCallback, &data);
22114c92070SAlex Zinenko 
22214c92070SAlex Zinenko     if (mlirLogicalResultIsFailure(result)) {
22314c92070SAlex Zinenko       throw py::value_error("Failed to infer result types");
22414c92070SAlex Zinenko     }
22514c92070SAlex Zinenko 
22614c92070SAlex Zinenko     return inferredTypes;
22714c92070SAlex Zinenko   }
22814c92070SAlex Zinenko 
bindDerived(ClassTy & cls)22914c92070SAlex Zinenko   static void bindDerived(ClassTy &cls) {
23014c92070SAlex Zinenko     cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
23114c92070SAlex Zinenko             py::arg("operands") = py::none(),
23214c92070SAlex Zinenko             py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
23314c92070SAlex Zinenko             py::arg("context") = py::none(), py::arg("loc") = py::none(),
23414c92070SAlex Zinenko             inferReturnTypesDoc);
23514c92070SAlex Zinenko   }
23614c92070SAlex Zinenko };
23714c92070SAlex Zinenko 
populateIRInterfaces(py::module & m)23814c92070SAlex Zinenko void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
23914c92070SAlex Zinenko 
24014c92070SAlex Zinenko } // namespace python
24114c92070SAlex Zinenko } // namespace mlir
242