1 //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/Interfaces.h"
10
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/CAPI/Wrap.h"
14 #include "mlir/Interfaces/InferTypeOpInterface.h"
15 #include "llvm/ADT/ScopeExit.h"
16
17 using namespace mlir;
18
mlirOperationImplementsInterface(MlirOperation operation,MlirTypeID interfaceTypeID)19 bool mlirOperationImplementsInterface(MlirOperation operation,
20 MlirTypeID interfaceTypeID) {
21 Optional<RegisteredOperationName> info =
22 unwrap(operation)->getRegisteredInfo();
23 return info && info->hasInterface(unwrap(interfaceTypeID));
24 }
25
mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,MlirContext context,MlirTypeID interfaceTypeID)26 bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
27 MlirContext context,
28 MlirTypeID interfaceTypeID) {
29 Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
30 StringRef(operationName.data, operationName.length), unwrap(context));
31 return info && info->hasInterface(unwrap(interfaceTypeID));
32 }
33
mlirInferTypeOpInterfaceTypeID()34 MlirTypeID mlirInferTypeOpInterfaceTypeID() {
35 return wrap(InferTypeOpInterface::getInterfaceID());
36 }
37
mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName,MlirContext context,MlirLocation location,intptr_t nOperands,MlirValue * operands,MlirAttribute attributes,intptr_t nRegions,MlirRegion * regions,MlirTypesCallback callback,void * userData)38 MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
39 MlirStringRef opName, MlirContext context, MlirLocation location,
40 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
41 intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
42 void *userData) {
43 StringRef name(opName.data, opName.length);
44 Optional<RegisteredOperationName> info =
45 RegisteredOperationName::lookup(name, unwrap(context));
46 if (!info)
47 return mlirLogicalResultFailure();
48
49 llvm::Optional<Location> maybeLocation = llvm::None;
50 if (!mlirLocationIsNull(location))
51 maybeLocation = unwrap(location);
52 SmallVector<Value> unwrappedOperands;
53 (void)unwrapList(nOperands, operands, unwrappedOperands);
54 DictionaryAttr attributeDict;
55 if (!mlirAttributeIsNull(attributes))
56 attributeDict = unwrap(attributes).cast<DictionaryAttr>();
57
58 // Create a vector of unique pointers to regions and make sure they are not
59 // deleted when exiting the scope. This is a hack caused by C++ API expecting
60 // an list of unique pointers to regions (without ownership transfer
61 // semantics) and C API making ownership transfer explicit.
62 SmallVector<std::unique_ptr<Region>> unwrappedRegions;
63 unwrappedRegions.reserve(nRegions);
64 for (intptr_t i = 0; i < nRegions; ++i)
65 unwrappedRegions.emplace_back(unwrap(*(regions + i)));
66 auto cleaner = llvm::make_scope_exit([&]() {
67 for (auto ®ion : unwrappedRegions)
68 region.release();
69 });
70
71 SmallVector<Type> inferredTypes;
72 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
73 unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
74 unwrappedRegions, inferredTypes)))
75 return mlirLogicalResultFailure();
76
77 SmallVector<MlirType> wrappedInferredTypes;
78 wrappedInferredTypes.reserve(inferredTypes.size());
79 for (Type t : inferredTypes)
80 wrappedInferredTypes.push_back(wrap(t));
81 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
82 return mlirLogicalResultSuccess();
83 }
84