1 //===- InferTypeOpInterface.cpp - Infer Type Interfaces ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the definitions of the infer op interfaces defined in 10 // `InferTypeOpInterface.td`. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Interfaces/InferTypeOpInterface.h" 15 16 #include "mlir/IR/StandardTypes.h" 17 18 using namespace mlir; 19 20 namespace mlir { 21 #include "mlir/Interfaces/InferTypeOpInterface.cpp.inc" 22 } // namespace mlir 23 24 LogicalResult mlir::detail::inferReturnTensorTypes( 25 function_ref<LogicalResult( 26 MLIRContext *, Optional<Location> location, ValueRange operands, 27 DictionaryAttr attributes, RegionRange regions, 28 SmallVectorImpl<ShapedTypeComponents> &retComponents)> 29 componentTypeFn, 30 MLIRContext *context, Optional<Location> location, ValueRange operands, 31 DictionaryAttr attributes, RegionRange regions, 32 SmallVectorImpl<Type> &inferredReturnTypes) { 33 SmallVector<ShapedTypeComponents, 2> retComponents; 34 if (failed(componentTypeFn(context, location, operands, attributes, regions, 35 retComponents))) 36 return failure(); 37 for (auto shapeAndType : retComponents) { 38 assert(shapeAndType.getAttribute() == nullptr && "attribute not supported"); 39 if (shapeAndType.hasRank()) 40 inferredReturnTypes.push_back(RankedTensorType::get( 41 shapeAndType.getDims(), shapeAndType.getElementType())); 42 else 43 inferredReturnTypes.push_back( 44 UnrankedTensorType::get(shapeAndType.getElementType())); 45 } 46 return success(); 47 } 48 49 LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { 50 SmallVector<Type, 4> inferredReturnTypes; 51 auto retTypeFn = cast<InferTypeOpInterface>(op); 52 if (failed(retTypeFn.inferReturnTypes( 53 op->getContext(), op->getLoc(), op->getOperands(), 54 op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) 55 return failure(); 56 if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes, 57 op->getResultTypes())) 58 return op->emitOpError( 59 "inferred type incompatible with return type of operation"); 60 return success(); 61 } 62