1 //===- Types.cpp - MLIR Type Classes --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/Types.h" 10 #include "TypeDetail.h" 11 #include "mlir/IR/Diagnostics.h" 12 #include "mlir/IR/Dialect.h" 13 #include "mlir/Support/LLVM.h" 14 #include "llvm/ADT/BitVector.h" 15 #include "llvm/ADT/Twine.h" 16 17 using namespace mlir; 18 using namespace mlir::detail; 19 20 //===----------------------------------------------------------------------===// 21 // Type 22 //===----------------------------------------------------------------------===// 23 24 Dialect &Type::getDialect() const { 25 return impl->getAbstractType().getDialect(); 26 } 27 28 MLIRContext *Type::getContext() const { return getDialect().getContext(); } 29 30 //===----------------------------------------------------------------------===// 31 // FunctionType 32 //===----------------------------------------------------------------------===// 33 34 FunctionType FunctionType::get(TypeRange inputs, TypeRange results, 35 MLIRContext *context) { 36 return Base::get(context, inputs, results); 37 } 38 39 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } 40 41 ArrayRef<Type> FunctionType::getInputs() const { 42 return getImpl()->getInputs(); 43 } 44 45 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } 46 47 ArrayRef<Type> FunctionType::getResults() const { 48 return getImpl()->getResults(); 49 } 50 51 /// Helper to call a callback once on each index in the range 52 /// [0, `totalIndices`), *except* for the indices given in `indices`. 53 /// `indices` is allowed to have duplicates and can be in any order. 54 inline void iterateIndicesExcept(unsigned totalIndices, 55 ArrayRef<unsigned> indices, 56 function_ref<void(unsigned)> callback) { 57 llvm::BitVector skipIndices(totalIndices); 58 for (unsigned i : indices) 59 skipIndices.set(i); 60 61 for (unsigned i = 0; i < totalIndices; ++i) 62 if (!skipIndices.test(i)) 63 callback(i); 64 } 65 66 /// Returns a new function type without the specified arguments and results. 67 FunctionType 68 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices, 69 ArrayRef<unsigned> resultIndices) { 70 ArrayRef<Type> newInputTypes = getInputs(); 71 SmallVector<Type, 4> newInputTypesBuffer; 72 if (!argIndices.empty()) { 73 unsigned originalNumArgs = getNumInputs(); 74 iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { 75 newInputTypesBuffer.emplace_back(getInput(i)); 76 }); 77 newInputTypes = newInputTypesBuffer; 78 } 79 80 ArrayRef<Type> newResultTypes = getResults(); 81 SmallVector<Type, 4> newResultTypesBuffer; 82 if (!resultIndices.empty()) { 83 unsigned originalNumResults = getNumResults(); 84 iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { 85 newResultTypesBuffer.emplace_back(getResult(i)); 86 }); 87 newResultTypes = newResultTypesBuffer; 88 } 89 90 return get(newInputTypes, newResultTypes, getContext()); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // OpaqueType 95 //===----------------------------------------------------------------------===// 96 97 OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData, 98 MLIRContext *context) { 99 return Base::get(context, dialect, typeData); 100 } 101 102 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, 103 MLIRContext *context, Location location) { 104 return Base::getChecked(location, dialect, typeData); 105 } 106 107 /// Returns the dialect namespace of the opaque type. 108 Identifier OpaqueType::getDialectNamespace() const { 109 return getImpl()->dialectNamespace; 110 } 111 112 /// Returns the raw type data of the opaque type. 113 StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; } 114 115 /// Verify the construction of an opaque type. 116 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, 117 Identifier dialect, 118 StringRef typeData) { 119 if (!Dialect::isValidNamespace(dialect.strref())) 120 return emitError(loc, "invalid dialect namespace '") << dialect << "'"; 121 return success(); 122 } 123