# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from typing import Union
    from ..ir import *
    from ._ods_common import get_default_loc_context as _get_default_loc_context
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from ._ml_program_ops_gen import *


ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"


class FuncOp:
    """Specialization for the func op class.

    NOTE: this class declares no base class; the `super().__init__` call in
    `__init__` relies on it being combined with the ODS-generated op class
    (presumably the one from `_ml_program_ops_gen`) — confirm against the
    extension-loading mechanism.
    """

    def __init__(
        self,
        name,
        type,
        *,
        visibility=None,
        body_builder=None,
        loc=None,
        ip=None,
    ):
        """Create a FuncOp with the provided `name`, `type`, and `visibility`.

        - `name` is converted with `str()` and stored as the `sym_name`
          StringAttr.
        - `type` is either a FunctionType or an `(inputs, results)` tuple of
          type lists, from which a FunctionType is built.
        - `visibility` is a string matching `public`, `private`, or `nested`.
          When None, no `sym_visibility` attribute is set, which MLIR's symbol
          rules treat as public.
        - `body_builder` is an optional callback. When provided, a new entry
          block is created and the callback is invoked with the new op as its
          argument, within an InsertionPoint context already set for that
          block. The callback is expected to insert a terminator in the block.
        - `loc` and `ip` are forwarded unchanged to the generated builder.
        """
        sym_name = StringAttr.get(str(name))

        # If the type is passed as a tuple, build a FunctionType on the fly.
        if isinstance(type, tuple):
            type = FunctionType.get(inputs=type[0], results=type[1])

        type = TypeAttr.get(type)
        # A None visibility leaves the attribute unset rather than storing "None".
        sym_visibility = (
            StringAttr.get(str(visibility)) if visibility is not None else None
        )
        super().__init__(
            sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip
        )
        if body_builder:
            entry_block = self.add_entry_block()
            with InsertionPoint(entry_block):
                body_builder(self)

    @property
    def is_external(self):
        """True when the body region contains no blocks (a declaration only)."""
        return len(self.regions[0].blocks) == 0

    @property
    def body(self):
        """The function's body region (region 0)."""
        return self.regions[0]

    @property
    def type(self):
        """The function signature, read from the `function_type` attribute."""
        return FunctionType(TypeAttr(self.attributes["function_type"]).value)

    @property
    def visibility(self):
        """The `sym_visibility` attribute, or None when it is not set.

        The attribute is absent when the op was built with `visibility=None`.
        Reading it directly would then raise KeyError, so that case is mapped
        back to None, mirroring the constructor.
        """
        try:
            return self.attributes["sym_visibility"]
        except KeyError:
            return None

    @property
    def name(self) -> StringAttr:
        """The function's symbol name (`sym_name`) as a StringAttr."""
        return StringAttr(self.attributes["sym_name"])

    @property
    def entry_block(self):
        """The first block of the body region.

        Raises:
            IndexError: if the function is external (has no body blocks).
        """
        if self.is_external:
            raise IndexError('External function does not have a body')
        return self.regions[0].blocks[0]

    def add_entry_block(self):
        """Add an entry block to the function body and return it.

        The block's arguments are taken from the function signature's input
        types.

        Raises:
            IndexError: if the function already has an entry block.
        """
        if not self.is_external:
            raise IndexError('The function already has an entry block!')
        self.body.blocks.append(*self.type.inputs)
        return self.body.blocks[0]

    def _as_array_attr(self, attribute):
        """Return `attribute` as an ArrayAttr, building one from a list if needed."""
        if isinstance(attribute, ArrayAttr):
            return attribute
        return ArrayAttr.get(attribute, context=self.context)

    @property
    def arg_attrs(self):
        """Per-argument attribute dictionaries, as an ArrayAttr.

        Raises KeyError if the attribute is not set.
        """
        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

    @arg_attrs.setter
    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
        self.attributes[ARGUMENT_ATTRIBUTE_NAME] = self._as_array_attr(attribute)

    @property
    def arguments(self):
        """Arguments of the entry block; raises IndexError if external."""
        return self.entry_block.arguments

    @property
    def result_attrs(self):
        """Per-result attribute dictionaries, as an ArrayAttr.

        Raises KeyError if the attribute is not set.
        """
        return ArrayAttr(self.attributes[RESULT_ATTRIBUTE_NAME])

    @result_attrs.setter
    def result_attrs(self, attribute: Union[ArrayAttr, list]):
        # Accepts a plain list as well, matching the `arg_attrs` setter.
        self.attributes[RESULT_ATTRIBUTE_NAME] = self._as_array_attr(attribute)