114c92070SAlex Zinenko //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
214c92070SAlex Zinenko //
314c92070SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
414c92070SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
514c92070SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
614c92070SAlex Zinenko //
714c92070SAlex Zinenko //===----------------------------------------------------------------------===//
814c92070SAlex Zinenko 
914c92070SAlex Zinenko #include "mlir-c/Interfaces.h"
1014c92070SAlex Zinenko 
1114c92070SAlex Zinenko #include "mlir/CAPI/IR.h"
12*4ae24d9fSBenjamin Kramer #include "mlir/CAPI/Support.h"
1314c92070SAlex Zinenko #include "mlir/CAPI/Wrap.h"
1414c92070SAlex Zinenko #include "mlir/Interfaces/InferTypeOpInterface.h"
1514c92070SAlex Zinenko #include "llvm/ADT/ScopeExit.h"
1614c92070SAlex Zinenko 
1714c92070SAlex Zinenko using namespace mlir;
1814c92070SAlex Zinenko 
mlirOperationImplementsInterface(MlirOperation operation,MlirTypeID interfaceTypeID)1914c92070SAlex Zinenko bool mlirOperationImplementsInterface(MlirOperation operation,
2014c92070SAlex Zinenko                                       MlirTypeID interfaceTypeID) {
21edc6c0ecSRiver Riddle   Optional<RegisteredOperationName> info =
22edc6c0ecSRiver Riddle       unwrap(operation)->getRegisteredInfo();
23edc6c0ecSRiver Riddle   return info && info->hasInterface(unwrap(interfaceTypeID));
2414c92070SAlex Zinenko }
2514c92070SAlex Zinenko 
mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,MlirContext context,MlirTypeID interfaceTypeID)2614c92070SAlex Zinenko bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
2714c92070SAlex Zinenko                                             MlirContext context,
2814c92070SAlex Zinenko                                             MlirTypeID interfaceTypeID) {
29edc6c0ecSRiver Riddle   Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
3014c92070SAlex Zinenko       StringRef(operationName.data, operationName.length), unwrap(context));
31edc6c0ecSRiver Riddle   return info && info->hasInterface(unwrap(interfaceTypeID));
3214c92070SAlex Zinenko }
3314c92070SAlex Zinenko 
mlirInferTypeOpInterfaceTypeID()3414c92070SAlex Zinenko MlirTypeID mlirInferTypeOpInterfaceTypeID() {
3514c92070SAlex Zinenko   return wrap(InferTypeOpInterface::getInterfaceID());
3614c92070SAlex Zinenko }
3714c92070SAlex Zinenko 
mlirInferTypeOpInterfaceInferReturnTypes(MlirStringRef opName,MlirContext context,MlirLocation location,intptr_t nOperands,MlirValue * operands,MlirAttribute attributes,intptr_t nRegions,MlirRegion * regions,MlirTypesCallback callback,void * userData)3814c92070SAlex Zinenko MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
3914c92070SAlex Zinenko     MlirStringRef opName, MlirContext context, MlirLocation location,
4014c92070SAlex Zinenko     intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
4114c92070SAlex Zinenko     intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
4214c92070SAlex Zinenko     void *userData) {
4314c92070SAlex Zinenko   StringRef name(opName.data, opName.length);
44edc6c0ecSRiver Riddle   Optional<RegisteredOperationName> info =
45edc6c0ecSRiver Riddle       RegisteredOperationName::lookup(name, unwrap(context));
46edc6c0ecSRiver Riddle   if (!info)
4714c92070SAlex Zinenko     return mlirLogicalResultFailure();
4814c92070SAlex Zinenko 
4914c92070SAlex Zinenko   llvm::Optional<Location> maybeLocation = llvm::None;
5014c92070SAlex Zinenko   if (!mlirLocationIsNull(location))
5114c92070SAlex Zinenko     maybeLocation = unwrap(location);
5214c92070SAlex Zinenko   SmallVector<Value> unwrappedOperands;
5314c92070SAlex Zinenko   (void)unwrapList(nOperands, operands, unwrappedOperands);
5414c92070SAlex Zinenko   DictionaryAttr attributeDict;
5514c92070SAlex Zinenko   if (!mlirAttributeIsNull(attributes))
5614c92070SAlex Zinenko     attributeDict = unwrap(attributes).cast<DictionaryAttr>();
5714c92070SAlex Zinenko 
5814c92070SAlex Zinenko   // Create a vector of unique pointers to regions and make sure they are not
5914c92070SAlex Zinenko   // deleted when exiting the scope. This is a hack caused by C++ API expecting
6014c92070SAlex Zinenko   // an list of unique pointers to regions (without ownership transfer
6114c92070SAlex Zinenko   // semantics) and C API making ownership transfer explicit.
6214c92070SAlex Zinenko   SmallVector<std::unique_ptr<Region>> unwrappedRegions;
6314c92070SAlex Zinenko   unwrappedRegions.reserve(nRegions);
6414c92070SAlex Zinenko   for (intptr_t i = 0; i < nRegions; ++i)
6514c92070SAlex Zinenko     unwrappedRegions.emplace_back(unwrap(*(regions + i)));
6614c92070SAlex Zinenko   auto cleaner = llvm::make_scope_exit([&]() {
6714c92070SAlex Zinenko     for (auto &region : unwrappedRegions)
6814c92070SAlex Zinenko       region.release();
6914c92070SAlex Zinenko   });
7014c92070SAlex Zinenko 
7114c92070SAlex Zinenko   SmallVector<Type> inferredTypes;
72edc6c0ecSRiver Riddle   if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
7314c92070SAlex Zinenko           unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
7414c92070SAlex Zinenko           unwrappedRegions, inferredTypes)))
7514c92070SAlex Zinenko     return mlirLogicalResultFailure();
7614c92070SAlex Zinenko 
7714c92070SAlex Zinenko   SmallVector<MlirType> wrappedInferredTypes;
7814c92070SAlex Zinenko   wrappedInferredTypes.reserve(inferredTypes.size());
7914c92070SAlex Zinenko   for (Type t : inferredTypes)
8014c92070SAlex Zinenko     wrappedInferredTypes.push_back(wrap(t));
8114c92070SAlex Zinenko   callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
8214c92070SAlex Zinenko   return mlirLogicalResultSuccess();
8314c92070SAlex Zinenko }
84