# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
    from typing import Any, List, Sequence, Union
    from ..ir import *
    from ._ods_common import get_default_loc_context
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e


class AllocTensorOp:
    """Extends the bufferization.alloc_tensor op.

    Provides a constructor that builds the operation via ``build_generic``
    and initializes this ``OpView`` with the result.
    """

    def __init__(self,
                 tensor_type: Type,
                 dynamic_sizes: Sequence[Value],
                 copy: Value,
                 escape: BoolAttr,
                 *,
                 loc=None,
                 ip=None):
        """Constructs an `alloc_tensor` with static and/or dynamic sizes.

        Args:
          tensor_type: Type of the single result produced by the op.
          dynamic_sizes: Values passed as the first operand group.
          copy: Value passed as the second operand group.
          escape: Stored as the ``escape`` attribute when truthy; a falsy
            value (e.g. ``None``) leaves the attribute unset.
          loc: Location forwarded to ``build_generic`` (defaults to None).
          ip: Insertion point forwarded to ``build_generic`` (defaults to None).
        """
        # NOTE(review): the returned context was never used. The call is kept
        # only in case it validates `loc` / the current location stack --
        # confirm that against `_ods_common`, and drop it if it has no such effect.
        get_default_loc_context(loc)

        attributes = {}
        if escape:
            attributes["escape"] = escape

        # Operand groups are given in order: dynamic sizes first, then copy.
        op = self.build_generic(
            results=[tensor_type],
            operands=[dynamic_sizes, copy],
            attributes=attributes,
            loc=loc,
            ip=ip,
        )
        OpView.__init__(self, op)