1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the definitions of the infer op interfaces defined in 10 // `InferTypeOpInterface.td`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/Matchers.h" 17 #include "llvm/Support/FormatVariadic.h" 18 19 using namespace mlir; 20 21 namespace mlir { 22 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" 23 } // namespace mlir 24 25 bool ShapeAdaptor::hasRank() const { 26 if (val.isNull()) 27 return false; 28 if (auto t = val.dyn_cast<Type>()) 29 return t.cast<ShapedType>().hasRank(); 30 if (val.is<Attribute>()) 31 return true; 32 return val.get<ShapedTypeComponents *>()->hasRank(); 33 } 34 35 Type ShapeAdaptor::getElementType() const { 36 if (val.isNull()) 37 return nullptr; 38 if (auto t = val.dyn_cast<Type>()) 39 return t.cast<ShapedType>().getElementType(); 40 if (val.is<Attribute>()) 41 return nullptr; 42 return val.get<ShapedTypeComponents *>()->getElementType(); 43 } 44 45 void ShapeAdaptor::getDims(SmallVectorImpl<int64_t> &res) const { 46 assert(hasRank()); 47 if (auto t = val.dyn_cast<Type>()) { 48 ArrayRef<int64_t> vals = t.cast<ShapedType>().getShape(); 49 res.assign(vals.begin(), vals.end()); 50 } else if (auto attr = val.dyn_cast<Attribute>()) { 51 auto dattr = attr.cast<DenseIntElementsAttr>(); 52 res.clear(); 53 res.reserve(dattr.size()); 54 for (auto it : dattr.getValues<APInt>()) 55 res.push_back(it.getSExtValue()); 56 } else { 57 auto vals = val.get<ShapedTypeComponents *>()->getDims(); 58 res.assign(vals.begin(), vals.end()); 59 } 60 } 61 62 void ShapeAdaptor::getDims(ShapedTypeComponents &res) const { 63 assert(hasRank()); 64 res.ranked = true; 65 getDims(res.dims); 66 } 67 68 int64_t ShapeAdaptor::getDimSize(int index) const { 69 assert(hasRank()); 70 if (auto t = val.dyn_cast<Type>()) 71 return t.cast<ShapedType>().getDimSize(index); 72 if (auto attr = val.dyn_cast<Attribute>()) 73 return attr.cast<DenseIntElementsAttr>() 74 .getValues<APInt>()[index] 75 .getSExtValue(); 76 auto *stc = val.get<ShapedTypeComponents *>(); 77 return stc->getDims()[index]; 78 } 79 80 int64_t ShapeAdaptor::getRank() const { 81 assert(hasRank()); 82 if (auto t = val.dyn_cast<Type>()) 83 return t.cast<ShapedType>().getRank(); 84 if (auto attr = val.dyn_cast<Attribute>()) 85 return attr.cast<DenseIntElementsAttr>().size(); 86 return val.get<ShapedTypeComponents *>()->getDims().size(); 87 } 88 89 bool ShapeAdaptor::hasStaticShape() const { 90 if (!hasRank()) 91 return false; 92 93 if (auto t = val.dyn_cast<Type>()) 94 return t.cast<ShapedType>().hasStaticShape(); 95 if (auto attr = val.dyn_cast<Attribute>()) { 96 auto dattr = attr.cast<DenseIntElementsAttr>(); 97 for (auto index : dattr.getValues<APInt>()) 98 if (ShapedType::isDynamic(index.getSExtValue())) 99 return false; 100 return true; 101 } 102 auto *stc = val.get<ShapedTypeComponents *>(); 103 return llvm::none_of(stc->getDims(), ShapedType::isDynamic); 104 } 105 106 int64_t ShapeAdaptor::getNumElements() const { 107 assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); 108 109 if (auto t = val.dyn_cast<Type>()) 110 return t.cast<ShapedType>().getNumElements(); 111 112 if (auto attr = val.dyn_cast<Attribute>()) { 113 auto dattr = attr.cast<DenseIntElementsAttr>(); 114 int64_t num = 1; 115 for (auto index : dattr.getValues<APInt>()) { 116 num *= index.getZExtValue(); 117 assert(num >= 0 && "integer overflow in element count computation"); 118 } 119 return num; 120 } 121 122 auto *stc = val.get<ShapedTypeComponents *>(); 123 int64_t num = 1; 124 for (int64_t dim : stc->getDims()) { 125 num *= dim; 126 assert(num >= 0 && "integer overflow in element count computation"); 127 } 128 return num; 129 } 130 131 void ShapeAdaptor::dump() const { 132 if (!hasRank()) { 133 llvm::errs() << "<<unranked>>\n"; 134 return; 135 } 136 137 SmallVector<int64_t> dims; 138 getDims(dims); 139 auto mapped = llvm::map_range(dims, [](int64_t dim) -> std::string { 140 if (ShapedType::isDynamic(dim)) 141 return "?"; 142 return llvm::formatv("{0}", dim).str(); 143 }); 144 llvm::errs() << "rank = " << getRank() << " dims = ["; 145 llvm::interleave(mapped, llvm::errs(), "x"); 146 llvm::errs() << "]\n"; 147 } 148 149 ShapeAdaptor ValueShapeRange::getValueAsShape(int index) { 150 Value val = operator[](index); 151 if (valueToShape) 152 if (ShapeAdaptor ret = valueToShape(val)) 153 return ret; 154 155 DenseIntElementsAttr attr; 156 if (!matchPattern(val, m_Constant(&attr))) 157 return nullptr; 158 if (attr.getType().getRank() != 1) 159 return nullptr; 160 return attr; 161 } 162 163 ShapeAdaptor ValueShapeRange::getShape(Value val) const { 164 if (operandShape) 165 if (ShapeAdaptor ret = operandShape(val)) 166 return ret; 167 return val.getType(); 168 } 169 170 ShapeAdaptor ValueShapeRange::getShape(int index) const { 171 if (index < 0 || static_cast<size_t>(index) >= size()) 172 return nullptr; 173 return getShape(operator[](index)); 174 } 175 176 LogicalResult mlir::detail::inferReturnTensorTypes( 177 function_ref<LogicalResult( 178 MLIRContext *, Optional<Location> location, ValueShapeRange operands, 179 DictionaryAttr attributes, RegionRange regions, 180 SmallVectorImpl<ShapedTypeComponents> &retComponents)> 181 componentTypeFn, 182 MLIRContext *context, Optional<Location> location, ValueRange operands, 183 DictionaryAttr attributes, RegionRange regions, 184 SmallVectorImpl<Type> &inferredReturnTypes) { 185 SmallVector<ShapedTypeComponents, 2> retComponents; 186 if (failed(componentTypeFn(context, location, operands, attributes, regions, 187 retComponents))) 188 return failure(); 189 for (const auto &shapeAndType : retComponents) { 190 assert(shapeAndType.getAttribute() == nullptr && "attribute not supported"); 191 assert(shapeAndType.getElementType() && 192 "element type required to construct tensor"); 193 if (shapeAndType.hasRank()) 194 inferredReturnTypes.push_back(RankedTensorType::get( 195 shapeAndType.getDims(), shapeAndType.getElementType())); 196 else 197 inferredReturnTypes.push_back( 198 UnrankedTensorType::get(shapeAndType.getElementType())); 199 } 200 return success(); 201 } 202 203 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { 204 SmallVector<Type, 4> inferredReturnTypes(op->getResultTypes()); 205 auto retTypeFn = cast<InferTypeOpInterface>(op); 206 return retTypeFn.refineReturnTypes(op->getContext(), op->getLoc(), 207 op->getOperands(), op->getAttrDictionary(), 208 op->getRegions(), inferredReturnTypes); 209 } 210