# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
  from ..ir import *
  from ._ods_common import get_default_loc_context as _get_default_loc_context

  import inspect

  from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
  raise RuntimeError("Error loading imports from extension module") from e

ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"


class ConstantOp:
  """Specialization for the constant op class."""

  def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None):
    super().__init__(result, value, loc=loc, ip=ip)

  @property
  def type(self):
    """The type of the first (and only expected) result of this op."""
    return self.results[0].type


class FuncOp:
  """Specialization for the func op class."""

  def __init__(self,
               name,
               type,
               *,
               visibility=None,
               body_builder=None,
               loc=None,
               ip=None):
    """
    Create a FuncOp with the provided `name`, `type`, and `visibility`.
    - `name` is converted with `str()` and stored as the `sym_name` attribute.
    - `type` is either a FunctionType or a tuple `(inputs, results)` of two
      type lists, from which a FunctionType is built.
    - `visibility` is a string matching `public`, `private`, or `nested`.
      None leaves the `sym_visibility` attribute unset, which MLIR treats as
      public visibility.
    - `body_builder` is an optional callback; when provided, a new entry block
      is created and the callback is invoked with the new op as argument within
      an InsertionPoint context already set for the block. The callback is
      expected to insert a terminator in the block.
    """
    sym_name = StringAttr.get(str(name))

    # If the type is passed as a tuple, build a FunctionType on the fly.
    if isinstance(type, tuple):
      type = FunctionType.get(inputs=type[0], results=type[1])

    type = TypeAttr.get(type)
    sym_visibility = (StringAttr.get(str(visibility))
                      if visibility is not None else None)
    super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
    if body_builder:
      entry_block = self.add_entry_block()
      with InsertionPoint(entry_block):
        body_builder(self)

  @property
  def is_external(self):
    """True when the body region contains no blocks."""
    return len(self.regions[0].blocks) == 0

  @property
  def body(self):
    """The function body region."""
    return self.regions[0]

  @property
  def type(self):
    """The FunctionType stored in the `function_type` attribute."""
    return FunctionType(TypeAttr(self.attributes["function_type"]).value)

  @property
  def visibility(self):
    """The raw `sym_visibility` attribute; it must be present on the op."""
    return self.attributes["sym_visibility"]

  @property
  def name(self) -> StringAttr:
    """The `sym_name` attribute as a StringAttr."""
    return StringAttr(self.attributes["sym_name"])

  @property
  def entry_block(self):
    """The first block of the body.

    Raises:
      IndexError: if the function is external (its body has no blocks).
    """
    if self.is_external:
      raise IndexError('External function does not have a body')
    return self.regions[0].blocks[0]

  def add_entry_block(self):
    """
    Add an entry block to the function body using the function signature to
    infer block arguments.
    Returns the newly created block.

    Raises:
      IndexError: if the function already has an entry block.
    """
    if not self.is_external:
      raise IndexError('The function already has an entry block!')
    self.body.blocks.append(*self.type.inputs)
    return self.body.blocks[0]

  @property
  def arg_attrs(self):
    """The `arg_attrs` attribute viewed as an ArrayAttr."""
    return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

  @arg_attrs.setter
  def arg_attrs(self, attribute: Union[ArrayAttr, list]):
    """Set `arg_attrs` from an ArrayAttr or from a list of attributes."""
    if isinstance(attribute, ArrayAttr):
      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
    else:
      self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
          attribute, context=self.context)

  @property
  def arguments(self):
    """The block arguments of the entry block."""
    return self.entry_block.arguments

  @property
  def result_attrs(self):
    """The raw `res_attrs` attribute."""
    return self.attributes[RESULT_ATTRIBUTE_NAME]

  @result_attrs.setter
  def result_attrs(self, attribute: Union[ArrayAttr, list]):
    """Set `res_attrs` from an attribute or from a list/tuple of attributes.

    Python lists and tuples are wrapped in an ArrayAttr, mirroring the
    `arg_attrs` setter; any other value is stored unchanged.
    """
    if isinstance(attribute, (list, tuple)):
      attribute = ArrayAttr.get(list(attribute), context=self.context)
    self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

  @classmethod
  def from_py_func(cls,
                   *inputs: Type,
                   results: Optional[Sequence[Type]] = None,
                   name: Optional[str] = None):
    """Decorator to define an MLIR FuncOp specified as a python function.

    Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
    active for the current thread (i.e. established in a `with` block).

    When applied as a decorator to a Python function, an entry block will
    be constructed for the FuncOp with types as specified in `*inputs`. The
    block arguments will be passed positionally to the Python function. In
    addition, if the Python function accepts keyword arguments generally or
    has a corresponding keyword argument, the following will be passed:
      * `func_op`: The `func` op being defined.

    By default, the function name will be the Python function `__name__`. This
    can be overridden by passing the `name` argument to the decorator.

    If `results` is not specified, then the decorator will implicitly
    insert a `ReturnOp` with the `Value`s returned from the decorated
    function. It will also set the `FuncOp` type with the actual return
    value types. If `results` is specified, then the decorated function
    must return `None` and no implicit `ReturnOp` is added (nor are the result
    types updated). The implicit behavior is intended for simple, single-block
    cases, and users should specify result types explicitly for any complicated
    cases.

    The decorated function can further be called from Python and will insert
    a `CallOp` at the then-current insertion point, returning None if there
    are no results, a single Value for one result, or the list of result
    Values otherwise. This mechanism cannot be used to emit recursive calls
    (by construction).
    """

    def decorator(f):
      from . import func
      # Introspect the callable: pass `func_op` if it accepts **kwargs or has
      # a parameter named `func_op` that can be bound by keyword.
      sig = inspect.signature(f)
      has_arg_func_op = any(
          param.kind == param.VAR_KEYWORD or
          (param.name == "func_op" and
           param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY))
          for param in sig.parameters.values())

      # Emit the FuncOp.
      implicit_return = results is None
      symbol_name = name or f.__name__
      function_type = FunctionType.get(
          inputs=inputs, results=[] if implicit_return else results)
      func_op = cls(name=symbol_name, type=function_type)
      with InsertionPoint(func_op.add_entry_block()):
        func_args = func_op.entry_block.arguments
        func_kwargs = {}
        if has_arg_func_op:
          func_kwargs["func_op"] = func_op
        return_values = f(*func_args, **func_kwargs)
        if not implicit_return:
          return_types = list(results)
          assert return_values is None, (
              "Capturing a python function with explicit `results=` "
              "requires that the wrapped function returns None.")
        else:
          # Coerce return values, add ReturnOp and rewrite func type.
          if return_values is None:
            return_values = []
          elif isinstance(return_values, tuple):
            return_values = list(return_values)
          elif isinstance(return_values, Value):
            # Returning a single value is fine, coerce it into a list.
            return_values = [return_values]
          elif isinstance(return_values, OpView):
            # Returning a single operation is fine, coerce its results into a
            # list.
            return_values = return_values.operation.results
          elif isinstance(return_values, Operation):
            # Returning a single operation is fine, coerce its results into a
            # list.
            return_values = return_values.results
          else:
            return_values = list(return_values)
          func.ReturnOp(return_values)
          # Recompute the function type from the values actually returned.
          return_types = [v.type for v in return_values]
          function_type = FunctionType.get(inputs=inputs, results=return_types)
          func_op.attributes["function_type"] = TypeAttr.get(function_type)

      def emit_call_op(*call_args):
        call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name),
                              call_args)
        # `return_types` is always a list, so test for emptiness (not None) to
        # honor the documented "None if no return values" contract.
        if not return_types:
          return None
        elif len(return_types) == 1:
          return call_op.result
        else:
          return call_op.results

      wrapped = emit_call_op
      wrapped.__name__ = f.__name__
      wrapped.func_op = func_op
      return wrapped

    return decorator


class CallOp:
  """Specialization for the call op class."""

  def __init__(self,
               calleeOrResults: Union[FuncOp, List[Type]],
               argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
               arguments: Optional[List] = None,
               *,
               loc=None,
               ip=None):
    """Creates a call operation.

    The constructor accepts three different forms:

      1. A function op to be called followed by a list of arguments.
      2. A list of result types, followed by the name of the function to be
         called as string, followed by a list of arguments.
      3. A list of result types, followed by the name of the function to be
         called as symbol reference attribute, followed by a list of arguments.

    For example

        f = func.FuncOp("foo", ...)
        func.CallOp(f, [args])
        func.CallOp([result_types], "foo", [args])

    In all cases, the location and insertion point may be specified as keyword
    arguments if not provided by the surrounding context managers.

    Raises:
      ValueError: if the arguments do not match any of the forms above.
    """

    # TODO: consider supporting constructor "overloads", e.g., through a custom
    # or pybind-provided metaclass.
    if isinstance(calleeOrResults, FuncOp):
      if not isinstance(argumentsOrCallee, list):
        raise ValueError(
            "when constructing a call to a function, expected " +
            "the second argument to be a list of call arguments, " +
            f"got {type(argumentsOrCallee)}")
      if arguments is not None:
        raise ValueError("unexpected third argument when constructing a call " +
                         "to a function")

      super().__init__(
          calleeOrResults.type.results,
          FlatSymbolRefAttr.get(
              calleeOrResults.name.value,
              context=_get_default_loc_context(loc)),
          argumentsOrCallee,
          loc=loc,
          ip=ip)
      return

    if isinstance(argumentsOrCallee, list):
      raise ValueError("when constructing a call to a function by name, " +
                       "expected the second argument to be a string or a " +
                       f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")

    if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
      super().__init__(
          calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip)
    elif isinstance(argumentsOrCallee, str):
      super().__init__(
          calleeOrResults,
          FlatSymbolRefAttr.get(
              argumentsOrCallee, context=_get_default_loc_context(loc)),
          arguments,
          loc=loc,
          ip=ip)
    else:
      # Previously this fell through without initializing the op at all.
      raise ValueError("when constructing a call to a function by name, " +
                       "expected the second argument to be a string or a " +
                       f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}")