1 //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Interfaces.h" 10 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/CAPI/Support.h" 13 #include "mlir/CAPI/Wrap.h" 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 #include "llvm/ADT/ScopeExit.h" 16 17 using namespace mlir; 18 19 bool mlirOperationImplementsInterface(MlirOperation operation, 20 MlirTypeID interfaceTypeID) { 21 Optional<RegisteredOperationName> info = 22 unwrap(operation)->getRegisteredInfo(); 23 return info && info->hasInterface(unwrap(interfaceTypeID)); 24 } 25 26 bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, 27 MlirContext context, 28 MlirTypeID interfaceTypeID) { 29 Optional<RegisteredOperationName> info = RegisteredOperationName::lookup( 30 StringRef(operationName.data, operationName.length), unwrap(context)); 31 return info && info->hasInterface(unwrap(interfaceTypeID)); 32 } 33 34 MlirTypeID mlirInferTypeOpInterfaceTypeID() { 35 return wrap(InferTypeOpInterface::getInterfaceID()); 36 } 37 38 MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( 39 MlirStringRef opName, MlirContext context, MlirLocation location, 40 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, 41 intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, 42 void *userData) { 43 StringRef name(opName.data, opName.length); 44 Optional<RegisteredOperationName> info = 45 RegisteredOperationName::lookup(name, unwrap(context)); 46 if (!info) 47 return mlirLogicalResultFailure(); 48 49 llvm::Optional<Location> maybeLocation = llvm::None; 50 if (!mlirLocationIsNull(location)) 51 maybeLocation = unwrap(location); 52 SmallVector<Value> unwrappedOperands; 53 (void)unwrapList(nOperands, operands, unwrappedOperands); 54 DictionaryAttr attributeDict; 55 if (!mlirAttributeIsNull(attributes)) 56 attributeDict = unwrap(attributes).cast<DictionaryAttr>(); 57 58 // Create a vector of unique pointers to regions and make sure they are not 59 // deleted when exiting the scope. This is a hack caused by C++ API expecting 60 // an list of unique pointers to regions (without ownership transfer 61 // semantics) and C API making ownership transfer explicit. 62 SmallVector<std::unique_ptr<Region>> unwrappedRegions; 63 unwrappedRegions.reserve(nRegions); 64 for (intptr_t i = 0; i < nRegions; ++i) 65 unwrappedRegions.emplace_back(unwrap(*(regions + i))); 66 auto cleaner = llvm::make_scope_exit([&]() { 67 for (auto ®ion : unwrappedRegions) 68 region.release(); 69 }); 70 71 SmallVector<Type> inferredTypes; 72 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes( 73 unwrap(context), maybeLocation, unwrappedOperands, attributeDict, 74 unwrappedRegions, inferredTypes))) 75 return mlirLogicalResultFailure(); 76 77 SmallVector<MlirType> wrappedInferredTypes; 78 wrappedInferredTypes.reserve(inferredTypes.size()); 79 for (Type t : inferredTypes) 80 wrappedInferredTypes.push_back(wrap(t)); 81 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); 82 return mlirLogicalResultSuccess(); 83 } 84