1 //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/Interfaces.h"
10 
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Wrap.h"
13 #include "mlir/Interfaces/InferTypeOpInterface.h"
14 #include "llvm/ADT/ScopeExit.h"
15 
16 using namespace mlir;
17 
18 bool mlirOperationImplementsInterface(MlirOperation operation,
19                                       MlirTypeID interfaceTypeID) {
20   Optional<RegisteredOperationName> info =
21       unwrap(operation)->getRegisteredInfo();
22   return info && info->hasInterface(unwrap(interfaceTypeID));
23 }
24 
25 bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
26                                             MlirContext context,
27                                             MlirTypeID interfaceTypeID) {
28   Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
29       StringRef(operationName.data, operationName.length), unwrap(context));
30   return info && info->hasInterface(unwrap(interfaceTypeID));
31 }
32 
33 MlirTypeID mlirInferTypeOpInterfaceTypeID() {
34   return wrap(InferTypeOpInterface::getInterfaceID());
35 }
36 
37 MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
38     MlirStringRef opName, MlirContext context, MlirLocation location,
39     intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
40     intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
41     void *userData) {
42   StringRef name(opName.data, opName.length);
43   Optional<RegisteredOperationName> info =
44       RegisteredOperationName::lookup(name, unwrap(context));
45   if (!info)
46     return mlirLogicalResultFailure();
47 
48   llvm::Optional<Location> maybeLocation = llvm::None;
49   if (!mlirLocationIsNull(location))
50     maybeLocation = unwrap(location);
51   SmallVector<Value> unwrappedOperands;
52   (void)unwrapList(nOperands, operands, unwrappedOperands);
53   DictionaryAttr attributeDict;
54   if (!mlirAttributeIsNull(attributes))
55     attributeDict = unwrap(attributes).cast<DictionaryAttr>();
56 
57   // Create a vector of unique pointers to regions and make sure they are not
58   // deleted when exiting the scope. This is a hack caused by C++ API expecting
59   // an list of unique pointers to regions (without ownership transfer
60   // semantics) and C API making ownership transfer explicit.
61   SmallVector<std::unique_ptr<Region>> unwrappedRegions;
62   unwrappedRegions.reserve(nRegions);
63   for (intptr_t i = 0; i < nRegions; ++i)
64     unwrappedRegions.emplace_back(unwrap(*(regions + i)));
65   auto cleaner = llvm::make_scope_exit([&]() {
66     for (auto &region : unwrappedRegions)
67       region.release();
68   });
69 
70   SmallVector<Type> inferredTypes;
71   if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
72           unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
73           unwrappedRegions, inferredTypes)))
74     return mlirLogicalResultFailure();
75 
76   SmallVector<MlirType> wrappedInferredTypes;
77   wrappedInferredTypes.reserve(inferredTypes.size());
78   for (Type t : inferredTypes)
79     wrappedInferredTypes.push_back(wrap(t));
80   callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
81   return mlirLogicalResultSuccess();
82 }
83