1# DExTer : Debugging Experience Tester
2# ~~~~~~   ~         ~~         ~   ~~
3#
4# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5# See https://llvm.org/LICENSE.txt for license information.
6# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
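"""Parse DExTer commands embedded in test source files. Command text is
evaluated with eval() in a namespace that names only the known DExTer commands
and their helpers, with the aim of limiting the Python that can appear in a
DExTer command. Note that this is not a security sandbox: eval() still exposes
Python's builtins to the command text, so only parse test files you trust.
"""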
11
12import os
13import unittest
14from copy import copy
15from pathlib import PurePath
16from collections import defaultdict, OrderedDict
17
18from dex.utils.Exceptions import CommandParseError, NonFloatValueInCommand
19
20from dex.command.CommandBase import CommandBase
21from dex.command.commands.DexCommandLine import DexCommandLine
22from dex.command.commands.DexDeclareFile import DexDeclareFile
23from dex.command.commands.DexDeclareAddress import DexDeclareAddress
24from dex.command.commands.DexExpectProgramState import DexExpectProgramState
25from dex.command.commands.DexExpectStepKind import DexExpectStepKind
26from dex.command.commands.DexExpectStepOrder import DexExpectStepOrder
27from dex.command.commands.DexExpectWatchType import DexExpectWatchType
28from dex.command.commands.DexExpectWatchValue import DexExpectWatchValue
29from dex.command.commands.DexExpectWatchBase import AddressExpression, DexExpectWatchBase
30from dex.command.commands.DexLabel import DexLabel
31from dex.command.commands.DexLimitSteps import DexLimitSteps
32from dex.command.commands.DexFinishTest import DexFinishTest
33from dex.command.commands.DexUnreachable import DexUnreachable
34from dex.command.commands.DexWatch import DexWatch
35from dex.utils import Timer
36from dex.utils.Exceptions import CommandParseError, DebuggerException
37
def _get_valid_commands():
    """Map each top-level DExTer command name to its command class.

    Returns:
        { name (str): command (class) }
    """
    command_classes = (
        DexCommandLine,
        DexDeclareAddress,
        DexDeclareFile,
        DexExpectProgramState,
        DexExpectStepKind,
        DexExpectStepOrder,
        DexExpectWatchType,
        DexExpectWatchValue,
        DexLabel,
        DexLimitSteps,
        DexFinishTest,
        DexUnreachable,
        DexWatch,
    )
    # The resulting dict keeps the insertion order of the tuple above.
    return {cls.get_name(): cls for cls in command_classes}
59
60
61def _get_command_name(command_raw: str) -> str:
62    """Return command name by splitting up DExTer command contained in
63    command_raw on the first opening paranthesis and further stripping
64    any potential leading or trailing whitespace.
65    """
66    return command_raw.split('(', 1)[0].rstrip()
67
68
69def _merge_subcommands(command_name: str, valid_commands: dict) -> dict:
70    """Merge valid_commands and command_name's subcommands into a new dict.
71
72    Returns:
73        { name (str): command (class) }
74    """
75    subcommands = valid_commands[command_name].get_subcommands()
76    if subcommands:
77        return { **valid_commands, **subcommands }
78    return valid_commands
79
80
def _build_command(command_type, labels, addresses, raw_text: str, path: str, lineno: int) -> CommandBase:
    """Build a command object by evaluating its raw text.

    Args:
        command_type: The command class named at the start of `raw_text`.
        labels: { label name (str): line } of the labels recorded so far.
        addresses: Names of the addresses declared so far.
        raw_text: The complete command text, e.g. "DexLabel('foo')".
        path: File the command applies to; stored on the command.
        lineno: 1-based line number of the command; stored on the command.

    This function will call eval().

    Raises:
        CommandParseError: if `ref()` names an unknown label or `address()`
            names an undeclared address.
        Any other exception that eval() can raise.

    Returns:
        A dexter command object.
    """
    def label_to_line(label_name: str) -> int:
        line = labels.get(label_name)
        if line is not None:
            return line
        raise format_unresolved_label_err(label_name, raw_text, path, lineno)

    def get_address_object(address_name: str, offset: int = 0):
        if address_name not in addresses:
            raise format_undeclared_address_err(address_name, raw_text, path, lineno)
        return AddressExpression(address_name, offset)

    # Names made available to the command text: the two helpers above, the
    # command itself and any subcommands it declares.
    valid_commands = _merge_subcommands(
        command_type.get_name(), {
            'ref': label_to_line,
            'address': get_address_object,
            command_type.get_name(): command_type,
        })

    # NOTE(review): eval() inserts a reference to the builtins module into a
    # globals dict that has no '__builtins__' key, so builtins (including
    # __import__) remain reachable from command text. This is not a sandbox;
    # only parse test files you trust.
    # pylint: disable=eval-used
    command = eval(raw_text, valid_commands)
    # pylint: enable=eval-used
    command.raw_text = raw_text
    command.path = path
    command.lineno = lineno
    return command
117
118
119def _search_line_for_cmd_start(line: str, start: int, valid_commands: dict) -> int:
120    """Scan `line` for a string matching any key in `valid_commands`.
121
122    Start searching from `start`.
123    Commands escaped with `\` (E.g. `\DexLabel('a')`) are ignored.
124
125    Returns:
126        int: the index of the first character of the matching string in `line`
127        or -1 if no command is found.
128    """
129    for command in valid_commands:
130        idx = line.find(command, start)
131        if idx != -1:
132            # Ignore escaped '\' commands.
133            if idx > 0 and line[idx - 1] == '\\':
134                continue
135            return idx
136    return -1
137
138
139def _search_line_for_cmd_end(line: str, start: int, paren_balance: int) -> (int, int):
140    """Find the end of a command by looking for balanced parentheses.
141
142    Args:
143        line: String to scan.
144        start: Index into `line` to start looking.
145        paren_balance(int): paren_balance after previous call.
146
147    Note:
148        On the first call `start` should point at the opening parenthesis and
149        `paren_balance` should be set to 0. Subsequent calls should pass in the
150        returned `paren_balance`.
151
152    Returns:
153        ( end,  paren_balance )
154        Where end is 1 + the index of the last char in the command or, if the
155        parentheses are not balanced, the end of the line.
156
157        paren_balance will be 0 when the parentheses are balanced.
158    """
159    for end in range(start, len(line)):
160        ch = line[end]
161        if ch == '(':
162            paren_balance += 1
163        elif ch == ')':
164            paren_balance -=1
165        if paren_balance == 0:
166            break
167    end += 1
168    return (end, paren_balance)
169
170
class TextPoint:
    """A mutable, zero-based (line, char) position within a list of lines."""

    def __init__(self, line, char):
        # Both coordinates are 0-based indices.
        self.line, self.char = line, char

    def get_lineno(self):
        """Return the 1-based line number, for display."""
        return 1 + self.line

    def get_column(self):
        """Return the 1-based column number, for display."""
        return 1 + self.char
181
182
def format_unresolved_label_err(label: str, src: str, filename: str, lineno) -> CommandParseError:
    """Build a CommandParseError for a label name that has not been recorded.

    The error shows the raw command text `src` with an empty caret; no attempt
    is made to point at the offending label inside the command.
    """
    err = CommandParseError()
    err.filename, err.lineno = filename, lineno
    err.src = src
    err.info = "Unresolved label: '{}'".format(label)
    err.caret = ''
    return err
191
def format_undeclared_address_err(address: str, src: str, filename: str, lineno) -> CommandParseError:
    """Build a CommandParseError for an address name that was never declared.

    The error shows the raw command text `src` with an empty caret; no attempt
    is made to point at the offending address inside the command.
    """
    err = CommandParseError()
    err.filename, err.lineno = filename, lineno
    err.src = src
    err.info = "Undeclared address: '{}'".format(address)
    err.caret = ''
    return err
200
def format_parse_err(msg: str, path: str, lines: list, point: TextPoint) -> CommandParseError:
    """Build a CommandParseError that points at `point` within `lines`.

    The offending line (trailing whitespace stripped) is reported as the
    source, with a caret placed under column `point.char`.
    """
    err = CommandParseError()
    err.info = msg
    err.filename = path
    err.src = lines[point.line].rstrip()
    err.lineno = point.get_lineno()
    padding = ' ' * point.char
    err.caret = padding + '<r>^</>'
    return err
209
210
def skip_horizontal_whitespace(line, point):
    """Advance `point.char` past any spaces and tabs in `line`.

    Only ' ' and '\\t' are skipped; newlines and other characters stop the
    scan. If the remainder of the line consists solely of spaces/tabs,
    `point.char` ends up at `len(line)`, so callers can detect the line end.
    A `point.char` already at or past the end of the line is left unchanged.
    """
    rest = line[point.char:]
    # Number of leading spaces/tabs in the remainder of the line.
    point.char += len(rest) - len(rest.lstrip(' \t'))
216
217
def add_line_label(labels, label, cmd_path, cmd_lineno):
    """Record `label` in `labels` as {label name: label.get_line()}.

    Raises:
        CommandParseError: if a label with the same name is already recorded.
    """
    # Evaluate the label name once, as add_address does for address names.
    label_name = label.eval()
    # Enforce unique line labels.
    if label_name in labels:
        err = CommandParseError()
        err.info = f'Found duplicate line label: \'{label_name}\''
        err.lineno = cmd_lineno
        err.filename = cmd_path
        err.src = label.raw_text
        # Don't bother trying to point to it since we're only printing the raw
        # command, which isn't much text.
        err.caret = ''
        raise err
    labels[label_name] = label.get_line()
231
def add_address(addresses, address, cmd_path, cmd_lineno):
    """Append the name of `address` to `addresses`, rejecting duplicates.

    Raises:
        CommandParseError: if the address name was already declared.
    """
    name = address.get_address_name()
    if name not in addresses:
        addresses.append(name)
        return
    # Duplicate address variable. Don't bother trying to point to it since
    # we're only printing the raw command, which isn't much text.
    err = CommandParseError()
    err.info = f'Found duplicate address: \'{name}\''
    err.lineno = cmd_lineno
    err.filename = cmd_path
    err.src = address.raw_text
    err.caret = ''
    raise err
246
def _find_all_commands_in_file(path, file_lines, valid_commands, source_root_dir):
    """Find, build and collect every DExTer command in `file_lines`.

    Args:
        path: Path of the file being parsed; used in error reports and as part
            of each command's key.
        file_lines: The file's lines (e.g. as returned by readlines()).
        valid_commands: { name (str): command (class) } to search for.
        source_root_dir: Directory that relative DexDeclareFile paths are
            resolved against; if falsy, the directory of `path` is used.

    Returns:
        ( commands, declared_files )
        commands: { command name: { (path, TextPoint): command object } }
        declared_files: set of paths named by DexDeclareFile commands.

    Raises:
        CommandParseError: for malformed, unrecognized or invalid commands.
    """
    labels = {} # dict of {name: line}.
    addresses = [] # list of addresses.
    cmd_path = path
    declared_files = set()
    commands = defaultdict(dict)
    paren_balance = 0
    region_start = TextPoint(0, 0)

    for region_start.line in range(len(file_lines)):
        line = file_lines[region_start.line]
        region_start.char = 0

        # Search this line till we find no more commands.
        while True:
            # If parens are currently balanced we can look for a new command.
            if paren_balance == 0:
                region_start.char = _search_line_for_cmd_start(line, region_start.char, valid_commands)
                if region_start.char == -1:
                    break # Read next line.

                command_name = _get_command_name(line[region_start.char:])
                cmd_point = copy(region_start)
                cmd_text_list = [command_name]

                region_start.char += len(command_name) # Start searching for parens after cmd.
                skip_horizontal_whitespace(line, region_start)
                if region_start.char >= len(line) or line[region_start.char] != '(':
                    raise format_parse_err(
                        "Missing open parenthesis", path, file_lines, region_start)
                # Extra text between a command name and '(' (e.g. "DexLabelX("
                # or "DexLabel x (") is not a known command; report it here
                # instead of failing later with a bare KeyError.
                if command_name not in valid_commands:
                    raise format_parse_err(
                        f"Unrecognized command: '{command_name}'", path,
                        file_lines, cmd_point)

            end, paren_balance = _search_line_for_cmd_end(line, region_start.char, paren_balance)
            # Add this text blob to the command.
            cmd_text_list.append(line[region_start.char:end])
            # Move parse ptr to end of line or parens.
            region_start.char = end

            # If the parens are unbalanced start reading the next line in an attempt
            # to find the end of the command.
            if paren_balance != 0:
                break  # Read next line.

            # Parens are balanced, we have a full command to evaluate.
            raw_text = "".join(cmd_text_list)
            try:
                command = _build_command(
                    valid_commands[command_name],
                    labels,
                    addresses,
                    raw_text,
                    cmd_path,
                    cmd_point.get_lineno(),
                )
            except SyntaxError as e:
                # This err should point to the problem line. e.lineno and
                # e.offset are 1-based positions within raw_text and may be
                # None, in which case point at the start of the command.
                raw_lineno = e.lineno or 1
                raw_offset = e.offset or 1
                err_point = copy(cmd_point)
                # Never point past the line on which the command ended.
                err_point.line = min(cmd_point.line + raw_lineno - 1,
                                     region_start.line)
                if raw_lineno == 1:
                    # The first line of raw_text begins at the command name.
                    err_point.char += raw_offset - 1
                else:
                    # Later lines of raw_text were copied from column 0.
                    err_point.char = raw_offset - 1
                raise format_parse_err(e.msg, path, file_lines, err_point)
            except (TypeError, NameError, NonFloatValueInCommand) as e:
                # This err should always point to the end of the command name.
                err_point = copy(cmd_point)
                err_point.char += len(command_name)
                raise format_parse_err(str(e), path, file_lines, err_point)
            else:
                if type(command) is DexLabel:
                    add_line_label(labels, command, path, cmd_point.get_lineno())
                elif type(command) is DexDeclareAddress:
                    add_address(addresses, command, path, cmd_point.get_lineno())
                elif type(command) is DexDeclareFile:
                    cmd_path = command.declared_file
                    if not os.path.isabs(cmd_path):
                        source_dir = (source_root_dir if source_root_dir else
                                      os.path.dirname(path))
                        cmd_path = os.path.join(source_dir, cmd_path)
                    # TODO: keep stored paths as PurePaths for 'longer'.
                    cmd_path = str(PurePath(cmd_path))
                    declared_files.add(cmd_path)
                elif type(command) is DexCommandLine and 'DexCommandLine' in commands:
                    msg = "More than one DexCommandLine in file"
                    # Point at the start of the duplicate command.
                    raise format_parse_err(msg, path, file_lines, cmd_point)

                assert (path, cmd_point) not in commands[command_name], (
                    command_name, commands[command_name])
                commands[command_name][path, cmd_point] = command

    if paren_balance != 0:
        # This err should always point to the end of the command name.
        err_point = copy(cmd_point)
        err_point.char += len(command_name)
        msg = "Unbalanced parenthesis starting here"
        raise format_parse_err(msg, path, file_lines, err_point)
    return dict(commands), declared_files
346
def _find_all_commands(test_files, source_root_dir):
    """Collect the DExTer commands from every file in `test_files`.

    Returns:
        ( commands, source_files )
        commands: { command name: { (path, TextPoint): command object } },
            merged across all of the files.
        source_files: union of the file paths declared by DexDeclareFile.
    """
    valid_commands = _get_valid_commands()
    merged = defaultdict(dict)
    source_files = set()
    for test_file in test_files:
        with open(test_file) as handle:
            file_lines = handle.readlines()
        per_file, declared = _find_all_commands_in_file(
            test_file, file_lines, valid_commands, source_root_dir)
        for name, found in per_file.items():
            merged[name].update(found)
        source_files.update(declared)
    return dict(merged), source_files
361
def get_command_infos(test_files, source_root_dir):
    """Parse every DExTer command in `test_files`.

    Args:
        test_files: Paths of the files to scan for commands.
        source_root_dir: Directory that relative DexDeclareFile paths are
            resolved against (each test file's directory is used when falsy).

    Returns:
        ( command_infos, new_source_files )
        command_infos: OrderedDict of { command name: [command objects] }.
        new_source_files: set of file paths declared via DexDeclareFile.

    Raises:
        DebuggerException: if any command fails to parse.
    """
    with Timer('parsing commands'):
        try:
            commands, new_source_files = _find_all_commands(test_files, source_root_dir)
        except CommandParseError as e:
            msg = 'parser error: <d>{}({}):</> {}\n{}\n{}\n'.format(
                e.filename, e.lineno, e.info, e.src, e.caret)
            raise DebuggerException(msg) from e
        # Only command types that actually have instances get an entry.
        command_infos = OrderedDict(
            (command_type, list(found.values()))
            for command_type, found in commands.items()
            if found)
        return command_infos, new_source_files
377
class TestParseCommand(unittest.TestCase):
    """Unit tests for the DExTer command parser, driven by a mock command."""

    class MockCmd(CommandBase):
        """A mock DExTer command for testing parsing.

        Args:
            value (str): Unique name for this instance.
        """

        def __init__(self, *args):
            self.value = args[0]

        @staticmethod
        def get_name():
            return __class__.__name__

        def eval(self):
            pass


    def __init__(self, *args):
        super().__init__(*args)

        self.valid_commands = {
            TestParseCommand.MockCmd.get_name() : TestParseCommand.MockCmd
        }


    def _find_all_commands_in_lines(self, lines):
        """Use DExTer parsing methods to find all the mock commands in lines.

        Returns:
            { cmd_name: { (path, TextPoint): command_obj } }
        """
        cmds, _ = _find_all_commands_in_file(__file__, lines, self.valid_commands, None)
        return cmds


    def _find_all_mock_values_in_lines(self, lines):
        """Use DExTer parsing methods to find all mock command values in lines.

        Returns:
            values (list(str)): MockCmd values found in lines.
        """
        cmds = self._find_all_commands_in_lines(lines)
        mocks = cmds.get(TestParseCommand.MockCmd.get_name(), None)
        return [v.value for v in mocks.values()] if mocks else []


    def test_parse_inline(self):
        """Commands can be embedded in other text."""

        lines = [
            'MockCmd("START") Lorem ipsum dolor sit amet, consectetur\n',
            'adipiscing elit, MockCmd("EMBEDDED") sed doeiusmod tempor,\n',
            'incididunt ut labore et dolore magna aliqua.\n'
        ]

        values = self._find_all_mock_values_in_lines(lines)

        self.assertIn('START', values)
        self.assertIn('EMBEDDED', values)


    def test_parse_multi_line_comment(self):
        """Multi-line commands can embed comments."""

        lines = [
            'Lorem ipsum dolor sit amet, consectetur\n',
            'adipiscing elit, sed doeiusmod tempor,\n',
            'incididunt ut labore et MockCmd(\n',
            '    "WITH_COMMENT" # THIS IS A COMMENT\n',
            ') dolore magna aliqua. Ut enim ad minim\n',
        ]

        values = self._find_all_mock_values_in_lines(lines)

        self.assertIn('WITH_COMMENT', values)

    def test_parse_empty(self):
        """Empty files are silently ignored."""

        lines = []
        values = self._find_all_mock_values_in_lines(lines)
        self.assertEqual(len(values), 0)

    def test_parse_bad_whitespace(self):
        """Throw exception when parsing badly formed whitespace."""
        lines = [
            'MockCmd\n',
            '("XFAIL_CMD_LF_PAREN")\n',
        ]

        with self.assertRaises(CommandParseError):
            self._find_all_mock_values_in_lines(lines)

    def test_parse_good_whitespace(self):
        """Try to emulate python whitespace rules"""

        lines = [
            'MockCmd("NONE")\n',
            'MockCmd    ("SPACE")\n',
            'MockCmd\t\t("TABS")\n',
            'MockCmd(    "ARG_SPACE"    )\n',
            'MockCmd(\t\t"ARG_TABS"\t\t)\n',
            'MockCmd(\n',
            '"CMD_PAREN_LF")\n',
        ]

        values = self._find_all_mock_values_in_lines(lines)

        self.assertIn('NONE', values)
        self.assertIn('SPACE', values)
        self.assertIn('TABS', values)
        self.assertIn('ARG_SPACE', values)
        self.assertIn('ARG_TABS', values)
        self.assertIn('CMD_PAREN_LF', values)


    def test_parse_share_line(self):
        """More than one command can appear on one line."""

        lines = [
            'MockCmd("START") MockCmd("CONSECUTIVE") words '
                'MockCmd("EMBEDDED") more words\n'
        ]

        values = self._find_all_mock_values_in_lines(lines)

        self.assertIn('START', values)
        self.assertIn('CONSECUTIVE', values)
        self.assertIn('EMBEDDED', values)


    def test_parse_escaped(self):
        """Escaped commands are ignored."""

        lines = [
            # '\\' is a literal backslash; '\M' would be an invalid escape.
            'words \\MockCmd("IGNORED") words words words\n'
        ]

        values = self._find_all_mock_values_in_lines(lines)

        self.assertNotIn('IGNORED', values)
520