1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the definitions of the infer op interfaces defined in 10 // `InferTypeOpInterface.td`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Matchers.h" 17 #include "llvm/Support/FormatVariadic.h" 18 19 using namespace mlir; 20 21 namespace mlir { 22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" 23 } // namespace mlir 24 25 bool ShapeAdaptor::hasRank() const { 26 if (val.isNull()) 27 return false; 28 if (auto t = val.dyn_cast<Type>()) 29 return t.cast<ShapedType>().hasRank(); 30 if (val.is<Attribute>()) 31 return true; 32 return val.get<ShapedTypeComponents *>()->hasRank(); 33 } 34 35 Type ShapeAdaptor::getElementType() const { 36 if (val.isNull()) 37 return nullptr; 38 if (auto t = val.dyn_cast<Type>()) 39 return t.cast<ShapedType>().getElementType(); 40 if (val.is<Attribute>()) 41 return nullptr; 42 return val.get<ShapedTypeComponents *>()->getElementType(); 43 } 44 45 void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const { 46 assert(hasRank()); 47 if (auto t = val.dyn_cast<Type>()) { 48 ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape(); 49 res.assign(vals.begin(), vals.end()); 50 } else if (auto attr = val.dyn_cast<Attribute>()) { 51 auto dattr = attr.cast<DenseIntElementsAttr>(); 52 res.clear(); 53 res.reserve(dattr.size()); 54 for (auto it : dattr.getValues<APInt>()) 55 res.push_back(it.getSExtValue()); 56 } else { 57 auto vals = val.get<ShapedTypeComponents *>()->getDims(); 58 res.assign(vals.begin(), vals.end()); 59 } 60 } 61 62 void ShapeAdaptor::getDims(ShapedTypeComponents &res) const { 63 assert(hasRank()); 64 res.ranked = true; 65 getDims(res.dims); 66 } 67 68 int64_t ShapeAdaptor::getDimSize(int index) const { 69 assert(hasRank()); 70 if (auto t = val.dyn_cast<Type>()) 71 return t.cast<ShapedType>().getDimSize(index); 72 if (auto attr = val.dyn_cast<Attribute>()) 73 return attr.cast<DenseIntElementsAttr>() 74 .getValues<APInt>()[index] 75 .getSExtValue(); 76 auto *stc = val.get<ShapedTypeComponents *>(); 77 return stc->getDims()[index]; 78 } 79 80 int64_t ShapeAdaptor::getRank() const { 81 assert(hasRank()); 82 if (auto t = val.dyn_cast<Type>()) 83 return t.cast<ShapedType>().getRank(); 84 if (auto attr = val.dyn_cast<Attribute>()) 85 return attr.cast<DenseIntElementsAttr>().size(); 86 return val.get<ShapedTypeComponents *>()->getDims().size(); 87 } 88 89 bool ShapeAdaptor::hasStaticShape() const { 90 if (!hasRank()) 91 return false; 92 93 if (auto t = val.dyn_cast<Type>()) 94 return t.cast<ShapedType>().hasStaticShape(); 95 if (auto attr = val.dyn_cast<Attribute>()) { 96 auto dattr = attr.cast<DenseIntElementsAttr>(); 97 for (auto index : dattr.getValues<APInt>()) 98 if (ShapedType::isDynamic(index.getSExtValue())) 99 return false; 100 return true; 101 } 102 auto *stc = val.get<ShapedTypeComponents *>(); 103 for (int64_t dim : stc->getDims()) 104 if (ShapedType::isDynamic(dim)) 105 return false; 106 return true; 107 } 108 109 int64_t ShapeAdaptor::getNumElements() const { 110 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 111 112 if (auto t = val.dyn_cast<Type>()) 113 return t.cast<ShapedType>().getNumElements(); 114 115 if (auto attr = val.dyn_cast<Attribute>()) { 116 auto dattr = attr.cast<DenseIntElementsAttr>(); 117 int64_t num = 1; 118 for (auto index : dattr.getValues<APInt>()) { 119 num *= index.getZExtValue(); 120 assert(num >= 0 && "integer overflow in element count computation"); 121 } 122 return num; 123 } 124 125 auto *stc = val.get<ShapedTypeComponents *>(); 126 int64_t num = 1; 127 for (int64_t dim : stc->getDims()) { 128 num *= dim; 129 assert(num >= 0 && "integer overflow in element count computation"); 130 } 131 return num; 132 } 133 134 void ShapeAdaptor::dump() const { 135 if (!hasRank()) { 136 llvm::errs() << "<<unranked>>\n"; 137 return; 138 } 139 140 SmallVector<int64_t> dims; 141 getDims(dims); 142 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string { 143 if (ShapedType::isDynamic(dim)) 144 return "?"; 145 return llvm::formatv("{0}", dim).str(); 146 }); 147 llvm::errs() << "rank = " << getRank() << " dims = ["; 148 llvm::interleave(mapped, llvm::errs(), "x"); 149 llvm::errs() << "]\n"; 150 } 151 152 ShapeAdaptor ValueShapeRange::getValueAsShape(int index) { 153 Value val = operator[](index); 154 if (valueToShape) 155 if (ShapeAdaptor ret = valueToShape(val)) 156 return ret; 157 158 DenseIntElementsAttr attr; 159 if (!matchPattern(val, m_Constant(&attr))) 160 return nullptr; 161 if (attr.getType().getRank() != 1) 162 return nullptr; 163 return attr; 164 } 165 166 ShapeAdaptor ValueShapeRange::getShape(Value val) const { 167 if (operandShape) 168 if (ShapeAdaptor ret = operandShape(val)) 169 return ret; 170 return val.getType(); 171 } 172 173 ShapeAdaptor ValueShapeRange::getShape(int index) const { 174 if (index < 0 || static_cast<size_t>(index) >= size()) 175 return nullptr; 176 return getShape(operator[](index)); 177 } 178 179 LogicalResult mlir::detail::inferReturnTensorTypes( 180 function_ref<LogicalResult( 181 MLIRContext *, Optional<Location> location, ValueShapeRange operands, 182 DictionaryAttr attributes, RegionRange regions, 183 SmallVectorImpl<ShapedTypeComponents> &retComponents)> 184 componentTypeFn, 185 MLIRContext *context, Optional<Location> location, ValueRange operands, 186 DictionaryAttr attributes, RegionRange regions, 187 SmallVectorImpl<Type> &inferredReturnTypes) { 188 SmallVector<ShapedTypeComponents, 2> retComponents; 189 if (failed(componentTypeFn(context, location, operands, attributes, regions, 190 retComponents))) 191 return failure(); 192 for (const auto &shapeAndType : retComponents) { 193 assert(shapeAndType.getAttribute() == nullptr && "attribute not supported"); 194 if (shapeAndType.hasRank()) 195 inferredReturnTypes.push_back(RankedTensorType::get( 196 shapeAndType.getDims(), shapeAndType.getElementType())); 197 else 198 inferredReturnTypes.push_back( 199 UnrankedTensorType::get(shapeAndType.getElementType())); 200 } 201 return success(); 202 } 203 204 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { 205 SmallVector<Type, 4> inferredReturnTypes; 206 auto retTypeFn = cast<InferTypeOpInterface>(op); 207 if (failed(retTypeFn.inferReturnTypes( 208 op->getContext(), op->getLoc(), op->getOperands(), 209 op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) 210 return failure(); 211 if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes, 212 op->getResultTypes())) 213 return op->emitOpError("inferred type(s) ") 214 << inferredReturnTypes 215 << " are incompatible with return type(s) of operation " 216 << op->getResultTypes(); 217 return success(); 218 } 219