1 //===- FuncOps.cpp - Func Dialect Operations ------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/IR/FuncOps.h" 10 11 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" 12 #include "mlir/IR/BlockAndValueMapping.h" 13 #include "mlir/IR/Builders.h" 14 #include "mlir/IR/BuiltinOps.h" 15 #include "mlir/IR/BuiltinTypes.h" 16 #include "mlir/IR/FunctionImplementation.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/OpImplementation.h" 19 #include "mlir/IR/PatternMatch.h" 20 #include "mlir/IR/TypeUtilities.h" 21 #include "mlir/IR/Value.h" 22 #include "mlir/Support/MathExtras.h" 23 #include "mlir/Transforms/InliningUtils.h" 24 #include "llvm/ADT/APFloat.h" 25 #include "llvm/ADT/MapVector.h" 26 #include "llvm/ADT/STLExtras.h" 27 #include "llvm/ADT/StringSwitch.h" 28 #include "llvm/Support/FormatVariadic.h" 29 #include "llvm/Support/raw_ostream.h" 30 #include <numeric> 31 32 #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" 33 34 using namespace mlir; 35 using namespace mlir::func; 36 37 //===----------------------------------------------------------------------===// 38 // FuncDialect Interfaces 39 //===----------------------------------------------------------------------===// 40 namespace { 41 /// This class defines the interface for handling inlining with func operations. 42 struct FuncInlinerInterface : public DialectInlinerInterface { 43 using DialectInlinerInterface::DialectInlinerInterface; 44 45 //===--------------------------------------------------------------------===// 46 // Analysis Hooks 47 //===--------------------------------------------------------------------===// 48 49 /// All call operations can be inlined. 50 bool isLegalToInline(Operation *call, Operation *callable, 51 bool wouldBeCloned) const final { 52 return true; 53 } 54 55 /// All operations can be inlined. 56 bool isLegalToInline(Operation *, Region *, bool, 57 BlockAndValueMapping &) const final { 58 return true; 59 } 60 61 /// All functions can be inlined. 62 bool isLegalToInline(Region *, Region *, bool, 63 BlockAndValueMapping &) const final { 64 return true; 65 } 66 67 //===--------------------------------------------------------------------===// 68 // Transformation Hooks 69 //===--------------------------------------------------------------------===// 70 71 /// Handle the given inlined terminator by replacing it with a new operation 72 /// as necessary. 73 void handleTerminator(Operation *op, Block *newDest) const final { 74 // Only return needs to be handled here. 75 auto returnOp = dyn_cast<ReturnOp>(op); 76 if (!returnOp) 77 return; 78 79 // Replace the return with a branch to the dest. 80 OpBuilder builder(op); 81 builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands()); 82 op->erase(); 83 } 84 85 /// Handle the given inlined terminator by replacing it with a new operation 86 /// as necessary. 87 void handleTerminator(Operation *op, 88 ArrayRef<Value> valuesToRepl) const final { 89 // Only return needs to be handled here. 90 auto returnOp = cast<ReturnOp>(op); 91 92 // Replace the values directly with the return operands. 93 assert(returnOp.getNumOperands() == valuesToRepl.size()); 94 for (const auto &it : llvm::enumerate(returnOp.getOperands())) 95 valuesToRepl[it.index()].replaceAllUsesWith(it.value()); 96 } 97 }; 98 } // namespace 99 100 //===----------------------------------------------------------------------===// 101 // FuncDialect 102 //===----------------------------------------------------------------------===// 103 104 void FuncDialect::initialize() { 105 addOperations< 106 #define GET_OP_LIST 107 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 108 >(); 109 addInterfaces<FuncInlinerInterface>(); 110 } 111 112 /// Materialize a single constant operation from a given attribute value with 113 /// the desired resultant type. 114 Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, 115 Type type, Location loc) { 116 if (ConstantOp::isBuildableWith(value, type)) 117 return builder.create<ConstantOp>(loc, type, 118 value.cast<FlatSymbolRefAttr>()); 119 return nullptr; 120 } 121 122 //===----------------------------------------------------------------------===// 123 // CallOp 124 //===----------------------------------------------------------------------===// 125 126 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { 127 // Check that the callee attribute was specified. 128 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee"); 129 if (!fnAttr) 130 return emitOpError("requires a 'callee' symbol reference attribute"); 131 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr); 132 if (!fn) 133 return emitOpError() << "'" << fnAttr.getValue() 134 << "' does not reference a valid function"; 135 136 // Verify that the operand and result types match the callee. 137 auto fnType = fn.getFunctionType(); 138 if (fnType.getNumInputs() != getNumOperands()) 139 return emitOpError("incorrect number of operands for callee"); 140 141 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) 142 if (getOperand(i).getType() != fnType.getInput(i)) 143 return emitOpError("operand type mismatch: expected operand type ") 144 << fnType.getInput(i) << ", but provided " 145 << getOperand(i).getType() << " for operand number " << i; 146 147 if (fnType.getNumResults() != getNumResults()) 148 return emitOpError("incorrect number of results for callee"); 149 150 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) 151 if (getResult(i).getType() != fnType.getResult(i)) { 152 auto diag = emitOpError("result type mismatch at index ") << i; 153 diag.attachNote() << " op result types: " << getResultTypes(); 154 diag.attachNote() << "function result types: " << fnType.getResults(); 155 return diag; 156 } 157 158 return success(); 159 } 160 161 FunctionType CallOp::getCalleeType() { 162 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); 163 } 164 165 //===----------------------------------------------------------------------===// 166 // CallIndirectOp 167 //===----------------------------------------------------------------------===// 168 169 /// Fold indirect calls that have a constant function as the callee operand. 170 LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, 171 PatternRewriter &rewriter) { 172 // Check that the callee is a constant callee. 173 SymbolRefAttr calledFn; 174 if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) 175 return failure(); 176 177 // Replace with a direct call. 178 rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn, 179 indirectCall.getResultTypes(), 180 indirectCall.getArgOperands()); 181 return success(); 182 } 183 184 //===----------------------------------------------------------------------===// 185 // ConstantOp 186 //===----------------------------------------------------------------------===// 187 188 LogicalResult ConstantOp::verify() { 189 StringRef fnName = getValue(); 190 Type type = getType(); 191 192 // Try to find the referenced function. 193 auto fn = (*this)->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnName); 194 if (!fn) 195 return emitOpError() << "reference to undefined function '" << fnName 196 << "'"; 197 198 // Check that the referenced function has the correct type. 199 if (fn.getFunctionType() != type) 200 return emitOpError("reference to function with mismatched type"); 201 202 return success(); 203 } 204 205 OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { 206 assert(operands.empty() && "constant has no operands"); 207 return getValueAttr(); 208 } 209 210 void ConstantOp::getAsmResultNames( 211 function_ref<void(Value, StringRef)> setNameFn) { 212 setNameFn(getResult(), "f"); 213 } 214 215 bool ConstantOp::isBuildableWith(Attribute value, Type type) { 216 return value.isa<FlatSymbolRefAttr>() && type.isa<FunctionType>(); 217 } 218 219 //===----------------------------------------------------------------------===// 220 // FuncOp 221 //===----------------------------------------------------------------------===// 222 223 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 224 ArrayRef<NamedAttribute> attrs) { 225 OpBuilder builder(location->getContext()); 226 OperationState state(location, getOperationName()); 227 FuncOp::build(builder, state, name, type, attrs); 228 return cast<FuncOp>(Operation::create(state)); 229 } 230 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 231 Operation::dialect_attr_range attrs) { 232 SmallVector<NamedAttribute, 8> attrRef(attrs); 233 return create(location, name, type, llvm::makeArrayRef(attrRef)); 234 } 235 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, 236 ArrayRef<NamedAttribute> attrs, 237 ArrayRef<DictionaryAttr> argAttrs) { 238 FuncOp func = create(location, name, type, attrs); 239 func.setAllArgAttrs(argAttrs); 240 return func; 241 } 242 243 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, 244 FunctionType type, ArrayRef<NamedAttribute> attrs, 245 ArrayRef<DictionaryAttr> argAttrs) { 246 state.addAttribute(SymbolTable::getSymbolAttrName(), 247 builder.getStringAttr(name)); 248 state.addAttribute(FunctionOpInterface::getTypeAttrName(), 249 TypeAttr::get(type)); 250 state.attributes.append(attrs.begin(), attrs.end()); 251 state.addRegion(); 252 253 if (argAttrs.empty()) 254 return; 255 assert(type.getNumInputs() == argAttrs.size()); 256 function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs, 257 /*resultAttrs=*/llvm::None); 258 } 259 260 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { 261 auto buildFuncType = 262 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results, 263 function_interface_impl::VariadicFlag, 264 std::string &) { return builder.getFunctionType(argTypes, results); }; 265 266 return function_interface_impl::parseFunctionOp( 267 parser, result, /*allowVariadic=*/false, buildFuncType); 268 } 269 270 void FuncOp::print(OpAsmPrinter &p) { 271 function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); 272 } 273 274 /// Clone the internal blocks from this function into dest and all attributes 275 /// from this function to dest. 276 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { 277 // Add the attributes of this function to dest. 278 llvm::MapVector<StringAttr, Attribute> newAttrMap; 279 for (const auto &attr : dest->getAttrs()) 280 newAttrMap.insert({attr.getName(), attr.getValue()}); 281 for (const auto &attr : (*this)->getAttrs()) 282 newAttrMap.insert({attr.getName(), attr.getValue()}); 283 284 auto newAttrs = llvm::to_vector(llvm::map_range( 285 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) { 286 return NamedAttribute(attrPair.first, attrPair.second); 287 })); 288 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); 289 290 // Clone the body. 291 getBody().cloneInto(&dest.getBody(), mapper); 292 } 293 294 /// Create a deep copy of this function and all of its blocks, remapping 295 /// any operands that use values outside of the function using the map that is 296 /// provided (leaving them alone if no entry is present). Replaces references 297 /// to cloned sub-values with the corresponding value that is copied, and adds 298 /// those mappings to the mapper. 299 FuncOp FuncOp::clone(BlockAndValueMapping &mapper) { 300 // Create the new function. 301 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions()); 302 303 // If the function has a body, then the user might be deleting arguments to 304 // the function by specifying them in the mapper. If so, we don't add the 305 // argument to the input type vector. 306 if (!isExternal()) { 307 FunctionType oldType = getFunctionType(); 308 309 unsigned oldNumArgs = oldType.getNumInputs(); 310 SmallVector<Type, 4> newInputs; 311 newInputs.reserve(oldNumArgs); 312 for (unsigned i = 0; i != oldNumArgs; ++i) 313 if (!mapper.contains(getArgument(i))) 314 newInputs.push_back(oldType.getInput(i)); 315 316 /// If any of the arguments were dropped, update the type and drop any 317 /// necessary argument attributes. 318 if (newInputs.size() != oldNumArgs) { 319 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, 320 oldType.getResults())); 321 322 if (ArrayAttr argAttrs = getAllArgAttrs()) { 323 SmallVector<Attribute> newArgAttrs; 324 newArgAttrs.reserve(newInputs.size()); 325 for (unsigned i = 0; i != oldNumArgs; ++i) 326 if (!mapper.contains(getArgument(i))) 327 newArgAttrs.push_back(argAttrs[i]); 328 newFunc.setAllArgAttrs(newArgAttrs); 329 } 330 } 331 } 332 333 /// Clone the current function into the new one and return it. 334 cloneInto(newFunc, mapper); 335 return newFunc; 336 } 337 FuncOp FuncOp::clone() { 338 BlockAndValueMapping mapper; 339 return clone(mapper); 340 } 341 342 //===----------------------------------------------------------------------===// 343 // ReturnOp 344 //===----------------------------------------------------------------------===// 345 346 LogicalResult ReturnOp::verify() { 347 auto function = cast<FuncOp>((*this)->getParentOp()); 348 349 // The operand number and types must match the function signature. 350 const auto &results = function.getFunctionType().getResults(); 351 if (getNumOperands() != results.size()) 352 return emitOpError("has ") 353 << getNumOperands() << " operands, but enclosing function (@" 354 << function.getName() << ") returns " << results.size(); 355 356 for (unsigned i = 0, e = results.size(); i != e; ++i) 357 if (getOperand(i).getType() != results[i]) 358 return emitError() << "type of return operand " << i << " (" 359 << getOperand(i).getType() 360 << ") doesn't match function result type (" 361 << results[i] << ")" 362 << " in function @" << function.getName(); 363 364 return success(); 365 } 366 367 //===----------------------------------------------------------------------===// 368 // TableGen'd op method definitions 369 //===----------------------------------------------------------------------===// 370 371 #define GET_OP_CLASSES 372 #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" 373