1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5try:
6  from typing import Sequence, Union
7  from ..ir import *
8  from ._ods_common import get_default_loc_context
9
10  from typing import Any, List, Union
11except ImportError as e:
12  raise RuntimeError("Error loading imports from extension module") from e
13
14
class AllocTensorOp:
  """Extension class for the bufferization.alloc_tensor op."""

  def __init__(self,
               tensor_type: Type,
               dynamic_sizes: Sequence[Value],
               copy: Value,
               escape: BoolAttr,
               *,
               loc=None,
               ip=None):
    """Builds an `alloc_tensor` op whose single result has `tensor_type`.

    Args:
      tensor_type: The type used as the op's only result type.
      dynamic_sizes: Values passed as the first operand group.
      copy: Value passed as the second operand group.
      escape: Recorded under the "escape" attribute when truthy; the
        attribute is left out entirely when falsy.
      loc: Forwarded unchanged to `build_generic`.
      ip: Forwarded unchanged to `build_generic`.
    """
    # The returned context is not consumed below; the call itself is kept so
    # that whatever it does with `loc` still happens before the op is built.
    # NOTE(review): confirm whether this lookup is still required.
    get_default_loc_context(loc)

    # Only attach "escape" when the caller supplied a truthy value.
    op_attributes = {"escape": escape} if escape else {}

    generic_op = self.build_generic(results=[tensor_type],
                                    operands=[dynamic_sizes, copy],
                                    attributes=op_attributes,
                                    loc=loc,
                                    ip=ip)
    OpView.__init__(self, generic_op)
38