xref: /llvm-project-15.0.7/mlir/lib/IR/Types.cpp (revision da121fff)
1 //===- Types.cpp - MLIR Type Classes --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/Types.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/Diagnostics.h"
12 #include "mlir/IR/Dialect.h"
13 #include "mlir/Support/LLVM.h"
14 #include "llvm/ADT/BitVector.h"
15 #include "llvm/ADT/Twine.h"
16 
17 using namespace mlir;
18 using namespace mlir::detail;
19 
20 //===----------------------------------------------------------------------===//
21 // Type
22 //===----------------------------------------------------------------------===//
23 
24 Dialect &Type::getDialect() const {
25   return impl->getAbstractType().getDialect();
26 }
27 
28 MLIRContext *Type::getContext() const { return getDialect().getContext(); }
29 
30 //===----------------------------------------------------------------------===//
31 // FunctionType
32 //===----------------------------------------------------------------------===//
33 
34 FunctionType FunctionType::get(TypeRange inputs, TypeRange results,
35                                MLIRContext *context) {
36   return Base::get(context, inputs, results);
37 }
38 
39 unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
40 
41 ArrayRef<Type> FunctionType::getInputs() const {
42   return getImpl()->getInputs();
43 }
44 
45 unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
46 
47 ArrayRef<Type> FunctionType::getResults() const {
48   return getImpl()->getResults();
49 }
50 
51 /// Helper to call a callback once on each index in the range
52 /// [0, `totalIndices`), *except* for the indices given in `indices`.
53 /// `indices` is allowed to have duplicates and can be in any order.
54 inline void iterateIndicesExcept(unsigned totalIndices,
55                                  ArrayRef<unsigned> indices,
56                                  function_ref<void(unsigned)> callback) {
57   llvm::BitVector skipIndices(totalIndices);
58   for (unsigned i : indices)
59     skipIndices.set(i);
60 
61   for (unsigned i = 0; i < totalIndices; ++i)
62     if (!skipIndices.test(i))
63       callback(i);
64 }
65 
66 /// Returns a new function type without the specified arguments and results.
67 FunctionType
68 FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
69                                        ArrayRef<unsigned> resultIndices) {
70   ArrayRef<Type> newInputTypes = getInputs();
71   SmallVector<Type, 4> newInputTypesBuffer;
72   if (!argIndices.empty()) {
73     unsigned originalNumArgs = getNumInputs();
74     iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
75       newInputTypesBuffer.emplace_back(getInput(i));
76     });
77     newInputTypes = newInputTypesBuffer;
78   }
79 
80   ArrayRef<Type> newResultTypes = getResults();
81   SmallVector<Type, 4> newResultTypesBuffer;
82   if (!resultIndices.empty()) {
83     unsigned originalNumResults = getNumResults();
84     iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
85       newResultTypesBuffer.emplace_back(getResult(i));
86     });
87     newResultTypes = newResultTypesBuffer;
88   }
89 
90   return get(newInputTypes, newResultTypes, getContext());
91 }
92 
93 //===----------------------------------------------------------------------===//
94 // OpaqueType
95 //===----------------------------------------------------------------------===//
96 
97 OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
98                            MLIRContext *context) {
99   return Base::get(context, dialect, typeData);
100 }
101 
102 OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
103                                   MLIRContext *context, Location location) {
104   return Base::getChecked(location, dialect, typeData);
105 }
106 
107 /// Returns the dialect namespace of the opaque type.
108 Identifier OpaqueType::getDialectNamespace() const {
109   return getImpl()->dialectNamespace;
110 }
111 
112 /// Returns the raw type data of the opaque type.
113 StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
114 
115 /// Verify the construction of an opaque type.
116 LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
117                                                        Identifier dialect,
118                                                        StringRef typeData) {
119   if (!Dialect::isValidNamespace(dialect.strref()))
120     return emitError(loc, "invalid dialect namespace '") << dialect << "'";
121   return success();
122 }
123