1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the definitions of the infer op interfaces defined in 10 // `InferTypeOpInterface.td`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 17 using namespace mlir; 18 19 namespace mlir { 20 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" 21 } // namespace mlir 22 23 LogicalResult mlir::detail::inferReturnTensorTypes( 24 function_ref<LogicalResult( 25 MLIRContext *, Optional<Location> location, ValueRange operands, 26 DictionaryAttr attributes, RegionRange regions, 27 SmallVectorImpl<ShapedTypeComponents> &retComponents)> 28 componentTypeFn, 29 MLIRContext *context, Optional<Location> location, ValueRange operands, 30 DictionaryAttr attributes, RegionRange regions, 31 SmallVectorImpl<Type> &inferredReturnTypes) { 32 SmallVector<ShapedTypeComponents, 2> retComponents; 33 if (failed(componentTypeFn(context, location, operands, attributes, regions, 34 retComponents))) 35 return failure(); 36 for (auto shapeAndType : retComponents) { 37 assert(shapeAndType.getAttribute() == nullptr && "attribute not supported"); 38 if (shapeAndType.hasRank()) 39 inferredReturnTypes.push_back(RankedTensorType::get( 40 shapeAndType.getDims(), shapeAndType.getElementType())); 41 else 42 inferredReturnTypes.push_back( 43 UnrankedTensorType::get(shapeAndType.getElementType())); 44 } 45 return success(); 46 } 47 48 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { 49 SmallVector<Type, 4> inferredReturnTypes; 50 auto retTypeFn = cast<InferTypeOpInterface>(op); 51 if (failed(retTypeFn.inferReturnTypes( 52 op->getContext(), op->getLoc(), op->getOperands(), 53 op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) 54 return failure(); 55 if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes, 56 op->getResultTypes())) 57 return op->emitOpError("inferred type(s) ") 58 << inferredReturnTypes 59 << " are incompatible with return type(s) of operation " 60 << op->getResultTypes(); 61 return success(); 62 } 63