1 //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir-c/Interfaces.h" 10 11 #include "mlir/CAPI/IR.h" 12 #include "mlir/CAPI/Wrap.h" 13 #include "mlir/Interfaces/InferTypeOpInterface.h" 14 #include "llvm/ADT/ScopeExit.h" 15 16 using namespace mlir; 17 18 bool mlirOperationImplementsInterface(MlirOperation operation, 19 MlirTypeID interfaceTypeID) { 20 Optional<RegisteredOperationName> info = 21 unwrap(operation)->getRegisteredInfo(); 22 return info && info->hasInterface(unwrap(interfaceTypeID)); 23 } 24 25 bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, 26 MlirContext context, 27 MlirTypeID interfaceTypeID) { 28 Optional<RegisteredOperationName> info = RegisteredOperationName::lookup( 29 StringRef(operationName.data, operationName.length), unwrap(context)); 30 return info && info->hasInterface(unwrap(interfaceTypeID)); 31 } 32 33 MlirTypeID mlirInferTypeOpInterfaceTypeID() { 34 return wrap(InferTypeOpInterface::getInterfaceID()); 35 } 36 37 MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( 38 MlirStringRef opName, MlirContext context, MlirLocation location, 39 intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, 40 intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, 41 void *userData) { 42 StringRef name(opName.data, opName.length); 43 Optional<RegisteredOperationName> info = 44 RegisteredOperationName::lookup(name, unwrap(context)); 45 if (!info) 46 return mlirLogicalResultFailure(); 47 48 llvm::Optional<Location> maybeLocation = llvm::None; 49 if (!mlirLocationIsNull(location)) 50 maybeLocation = unwrap(location); 51 SmallVector<Value> unwrappedOperands; 52 (void)unwrapList(nOperands, operands, unwrappedOperands); 53 DictionaryAttr attributeDict; 54 if (!mlirAttributeIsNull(attributes)) 55 attributeDict = unwrap(attributes).cast<DictionaryAttr>(); 56 57 // Create a vector of unique pointers to regions and make sure they are not 58 // deleted when exiting the scope. This is a hack caused by C++ API expecting 59 // an list of unique pointers to regions (without ownership transfer 60 // semantics) and C API making ownership transfer explicit. 61 SmallVector<std::unique_ptr<Region>> unwrappedRegions; 62 unwrappedRegions.reserve(nRegions); 63 for (intptr_t i = 0; i < nRegions; ++i) 64 unwrappedRegions.emplace_back(unwrap(*(regions + i))); 65 auto cleaner = llvm::make_scope_exit([&]() { 66 for (auto ®ion : unwrappedRegions) 67 region.release(); 68 }); 69 70 SmallVector<Type> inferredTypes; 71 if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes( 72 unwrap(context), maybeLocation, unwrappedOperands, attributeDict, 73 unwrappedRegions, inferredTypes))) 74 return mlirLogicalResultFailure(); 75 76 SmallVector<MlirType> wrappedInferredTypes; 77 wrappedInferredTypes.reserve(inferredTypes.size()); 78 for (Type t : inferredTypes) 79 wrappedInferredTypes.push_back(wrap(t)); 80 callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); 81 return mlirLogicalResultSuccess(); 82 } 83