1# SPDX-License-Identifier: GPL-2.0
2
3import os
4import time
5from pathlib import Path
6from lib.py import KsftSkipEx, KsftXfailEx
7from lib.py import ksft_setup
8from lib.py import cmd, ethtool, ip, CmdExitFailure
9from lib.py import NetNS, NetdevSimDev
10from .remote import Remote
11
12
class NetDrvEnvBase:
    """
    Base class for NIC / host environments.

    Keeps the path of the test which constructed the env (src_path) and
    the configuration dictionary (env). The configuration starts as a copy
    of the process environment, is extended with the contents of an optional
    "net.config" file located next to the test, and is then passed through
    ksft_setup().
    """
    def __init__(self, src_path):
        self.src_path = src_path
        self.env = self._load_env_file()

    def _src_dir(self):
        """
        Get the resolved directory containing the test (src_path).
        """
        return Path(self.src_path).parent.resolve()

    def rpath(self, path):
        """
        Get an absolute path to a file based on a path relative to the directory
        containing the test which constructed env.

        For example, if the test.py is in the same directory as
        a binary (built from helper.c), the test can use env.rpath("helper")
        to get the absolute path to the binary
        """
        return (self._src_dir() / path).as_posix()

    @staticmethod
    def _parse_config_line(line):
        """
        Parse a single line of the net.config file.

        Everything following the first "#" is treated as a comment.
        Returns None for blank and comment-only lines, otherwise a
        (key, value) tuple. Whitespace around the key and the value is
        dropped, so "KEY = value" is equivalent to "KEY=value".

        Raises Exception if the line is not in KEY=VALUE format or the key
        is empty.
        """
        # Strip comments
        content = line.partition("#")[0].strip()
        if not content:
            return None

        # Only the first "=" separates the key, the value may contain more
        key, sep, value = content.partition("=")
        key = key.strip()
        if not sep or not key:
            raise Exception("Can't parse configuration line:", line)
        return key, value.strip()

    def _load_env_file(self):
        """
        Build the configuration dictionary for this environment.

        Values from net.config (if the file exists) override values
        inherited from the process environment.
        """
        env = os.environ.copy()

        cfg_path = self._src_dir() / "net.config"
        if not cfg_path.exists():
            return ksft_setup(env)

        with open(cfg_path.as_posix(), 'r') as fp:
            for line in fp:
                parsed = self._parse_config_line(line)
                if parsed is None:
                    continue
                key, value = parsed
                env[key] = value
        return ksft_setup(env)
55
56
class NetDrvEnv(NetDrvEnvBase):
    """
    Class for a single NIC / host env, with no remote end.

    If NETIF is present in the env the named interface is used, otherwise
    a netdevsim device is created and its first interface (nsims[0]) is used.
    """
    def __init__(self, src_path, **kwargs):
        """
        src_path is the path of the test constructing the env.
        kwargs are passed to NetdevSimDev() and are only used when NETIF
        is not configured.
        """
        super().__init__(src_path)

        # Only set when we created the netdevsim ourselves (and must remove it)
        self._ns = None

        if 'NETIF' in self.env:
            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]
        else:
            self._ns = NetdevSimDev(**kwargs)
            self.dev = self._ns.nsims[0].dev
        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

    def __enter__(self):
        """
        Bring the device up when entering the "with" block.
        """
        ip(f"link set dev {self.ifname} up")

        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        __exit__ gets called at the end of a "with" block.
        """
        self.__del__()

    def __del__(self):
        """
        Remove the netdevsim device, if one was created by this env.
        """
        # __del__ also runs for objects whose __init__ raised before
        # self._ns was assigned, so don't assume the attribute exists.
        if getattr(self, "_ns", None):
            self._ns.remove()
            self._ns = None
89
90
class NetDrvEpEnv(NetDrvEnvBase):
    """
    Class for an environment with a local device and "remote endpoint"
    which can be used to send traffic in.

    For local testing (NETIF not configured) it creates a network namespace
    for the peer and a pair of linked netdevsim devices.
    """

    # Network prefixes used for local tests
    nsim_v4_pfx = "192.0.2."
    nsim_v6_pfx = "2001:db8::"

    def __init__(self, src_path, nsim_test=None):
        """
        src_path is the path of the test constructing the env.

        nsim_test restricts where the test may run:
          True  - raise KsftXfailEx if a real device (NETIF) is configured
          False - raise KsftXfailEx if NETIF is not configured (netdevsim)
          None  - no restriction
        """
        super().__init__(src_path)

        self._stats_settle_time = None

        # Things we try to destroy
        self.remote = None
        # These are for local testing state
        self._netns = None
        self._ns = None
        self._ns_peer = None

        # Addresses keyed by IP version ("4" / "6"), None if not available
        self.addr_v = {"4": None, "6": None}
        self.remote_addr_v = {"4": None, "6": None}

        if "NETIF" in self.env:
            if nsim_test is True:
                raise KsftXfailEx("Test only works on netdevsim")
            self._check_env()

            self.dev = ip("-d link show dev " + self.env['NETIF'], json=True)[0]

            self.addr_v["4"] = self.env.get("LOCAL_V4")
            self.addr_v["6"] = self.env.get("LOCAL_V6")
            self.remote_addr_v["4"] = self.env.get("REMOTE_V4")
            self.remote_addr_v["6"] = self.env.get("REMOTE_V6")
            kind = self.env["REMOTE_TYPE"]
            args = self.env["REMOTE_ARGS"]
        else:
            if nsim_test is False:
                raise KsftXfailEx("Test does not work on netdevsim")

            self.create_local()

            self.dev = self._ns.nsims[0].dev

            self.addr_v["4"] = self.nsim_v4_pfx + "1"
            self.addr_v["6"] = self.nsim_v6_pfx + "1"
            self.remote_addr_v["4"] = self.nsim_v4_pfx + "2"
            self.remote_addr_v["6"] = self.nsim_v6_pfx + "2"
            kind = "netns"
            args = self._netns.name

        self.remote = Remote(kind, args, src_path)

        # Default address family: IPv6 if configured, IPv4 otherwise
        self.addr_ipver = "6" if self.addr_v["6"] else "4"
        self.addr = self.addr_v[self.addr_ipver]
        self.remote_addr = self.remote_addr_v[self.addr_ipver]

        # Bracketed addresses, some commands need IPv6 to be inside []
        self.baddr = f"[{self.addr_v['6']}]" if self.addr_v["6"] else self.addr_v["4"]
        self.remote_baddr = f"[{self.remote_addr_v['6']}]" if self.remote_addr_v["6"] else self.remote_addr_v["4"]

        self.ifname = self.dev['ifname']
        self.ifindex = self.dev['ifindex']

        # resolve remote interface name
        self.remote_ifname = self.resolve_remote_ifc()

        # Cache of command availability: {command: {"local"/"remote": bool}}
        self._required_cmd = {}

    def create_local(self):
        """
        Create the local test setup: a peer netns, one netdevsim in the
        current netns and one in the peer netns, link the two and assign
        the nsim_v4_pfx / nsim_v6_pfx addresses (.1 / ::1 local, .2 / ::2 peer).
        """
        self._netns = NetNS()
        self._ns = NetdevSimDev()
        self._ns_peer = NetdevSimDev(ns=self._netns)

        # link_device takes "<netns fd>:<ifindex>" for each end, so the
        # namespace files must stay open while the control file is written
        with open("/proc/self/ns/net") as nsfd0, \
             open("/var/run/netns/" + self._netns.name) as nsfd1:
            ifi0 = self._ns.nsims[0].ifindex
            ifi1 = self._ns_peer.nsims[0].ifindex
            NetdevSimDev.ctrl_write('link_device',
                                    f'{nsfd0.fileno()}:{ifi0} {nsfd1.fileno()}:{ifi1}')

        ip(f"   addr add dev {self._ns.nsims[0].ifname} {self.nsim_v4_pfx}1/24")
        ip(f"-6 addr add dev {self._ns.nsims[0].ifname} {self.nsim_v6_pfx}1/64 nodad")
        ip(f"   link set dev {self._ns.nsims[0].ifname} up")

        ip(f"   addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v4_pfx}2/24", ns=self._netns)
        ip(f"-6 addr add dev {self._ns_peer.nsims[0].ifname} {self.nsim_v6_pfx}2/64 nodad", ns=self._netns)
        ip(f"   link set dev {self._ns_peer.nsims[0].ifname} up", ns=self._netns)

    def _check_env(self):
        """
        Validate the configuration needed to run against a real device.

        Each entry of vars_needed is a list of alternatives, at least one
        of which must be present in the env. Raises Exception listing
        the missing groups.
        """
        vars_needed = [
            ["LOCAL_V4", "LOCAL_V6"],
            ["REMOTE_V4", "REMOTE_V6"],
            ["REMOTE_TYPE"],
            ["REMOTE_ARGS"]
        ]
        missing = [choice for choice in vars_needed
                   if not any(entry in self.env for entry in choice)]

        # Make sure v4 / v6 configs are symmetric
        if ("LOCAL_V6" in self.env) != ("REMOTE_V6" in self.env):
            missing.append(["LOCAL_V6", "REMOTE_V6"])
        if ("LOCAL_V4" in self.env) != ("REMOTE_V4" in self.env):
            missing.append(["LOCAL_V4", "REMOTE_V4"])
        if missing:
            raise Exception("Invalid environment, missing configuration:", missing,
                            "Please see tools/testing/selftests/drivers/net/README.rst")

    def resolve_remote_ifc(self):
        """
        Find the name of the remote interface which carries the remote
        address(es), by running "ip addr show to <addr>" on the remote.

        Raises Exception if the v4 and v6 addresses are on different
        interfaces, if more than one interface matches, or if no interface
        matches at all.
        """
        v4 = v6 = None
        if self.remote_addr_v["4"]:
            v4 = ip("addr show to " + self.remote_addr_v["4"], json=True, host=self.remote)
        if self.remote_addr_v["6"]:
            v6 = ip("addr show to " + self.remote_addr_v["6"], json=True, host=self.remote)
        if v4 and v6 and v4[0]["ifname"] != v6[0]["ifname"]:
            raise Exception("Can't resolve remote interface name, v4 and v6 don't match")
        if (v4 and len(v4) > 1) or (v6 and len(v6) > 1):
            raise Exception("Can't resolve remote interface name, multiple interfaces match")

        # Prefer the v6 result, fall back to v4; either may be None (address
        # not configured) or empty (nothing on the remote has the address)
        matches = v6 or v4
        if not matches:
            raise Exception("Can't resolve remote interface name, no interface matches")
        return matches[0]["ifname"]

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        """
        __exit__ gets called at the end of a "with" block.
        """
        self.__del__()

    def __del__(self):
        """
        Tear down whatever this env created: netdevsim devices, the peer
        netns and the remote handle.
        """
        # __del__ also runs for objects whose __init__ raised before these
        # attributes were assigned, so don't assume they exist.
        if getattr(self, "_ns", None):
            self._ns.remove()
            self._ns = None
        if getattr(self, "_ns_peer", None):
            self._ns_peer.remove()
            self._ns_peer = None
        if getattr(self, "_netns", None):
            del self._netns
            self._netns = None
        if getattr(self, "remote", None):
            del self.remote
            self.remote = None

    def require_ipver(self, ipver):
        """
        Raise KsftSkipEx unless both the local and the remote address are
        available for the given IP version ("4" or "6").
        """
        if not self.addr_v[ipver] or not self.remote_addr_v[ipver]:
            raise KsftSkipEx(f"Test requires IPv{ipver} connectivity")

    def _require_cmd(self, comm, key, host=None):
        """
        Check whether command comm is available, via "command -v".

        The result is cached under key ("local" / "remote"); host is
        passed to cmd() to select where the check is executed.
        """
        cached = self._required_cmd.get(comm, {})
        if cached.get(key) is None:
            cached[key] = cmd("command -v -- " + comm, fail=False,
                              shell=True, host=host).ret == 0
        self._required_cmd[comm] = cached
        return cached[key]

    def require_cmd(self, comm, local=True, remote=False):
        """
        Raise KsftSkipEx if command comm is not available on the local
        system (local=True) and / or on the remote endpoint (remote=True).
        """
        if local:
            if not self._require_cmd(comm, "local"):
                raise KsftSkipEx("Test requires command: " + comm)
        if remote:
            # The check has to be executed on the remote endpoint
            if not self._require_cmd(comm, "remote", host=self.remote):
                raise KsftSkipEx("Test requires (remote) command: " + comm)

    def wait_hw_stats_settle(self):
        """
        Wait for HW stats to become consistent, some devices DMA HW stats
        periodically so events won't be reflected until next sync.
        Good drivers will tell us via ethtool what their sync period is.
        """
        if self._stats_settle_time is None:
            data = {}
            try:
                data = ethtool("-c " + self.ifname, json=True)[0]
            except CmdExitFailure as e:
                # Devices without coalescing support just get the base delay
                if "Operation not supported" not in e.cmd.stderr:
                    raise

            # 25 msec base delay plus the reported sync period (usecs -> secs)
            self._stats_settle_time = 0.025 + \
                data.get('stats-block-usecs', 0) / 1000 / 1000

        time.sleep(self._stats_settle_time)
282