1# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2# See https://llvm.org/LICENSE.txt for license information.
3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from subprocess import Popen
6import os
7import subprocess
8import tempfile
9import traceback
10from ipykernel.kernelbase import Kernel
11
12__version__ = '0.0.1'
13
14
15def _get_executable():
16    """Find the mlir-opt executable."""
17
18    def is_exe(fpath):
19        """Returns whether executable file."""
20        return os.path.isfile(fpath) and os.access(fpath, os.X_OK)
21
22    program = os.environ.get('MLIR_OPT_EXECUTABLE', 'mlir-opt')
23    path, name = os.path.split(program)
24    # Attempt to get the executable
25    if path:
26        if is_exe(program):
27            return program
28    else:
29        for path in os.environ["PATH"].split(os.pathsep):
30            file = os.path.join(path, name)
31            if is_exe(file):
32                return file
33    raise OSError('mlir-opt not found, please see README')
34
35
class MlirOptKernel(Kernel):
    """Kernel using mlir-opt inside jupyter.

    The reproducer syntax (`// configuration:`) is used to run passes. The
    previous result can be referenced by using `_` as the last line of the
    cell (this variable is reset upon error). E.g.,

    ```mlir
    // configuration: --pass
    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
    ```

    ```mlir
    // configuration: --next-pass
    _
    ```
    """

    implementation = 'mlir'
    implementation_version = __version__

    language_version = __version__
    language = "mlir"
    language_info = {
        "name": "mlir",
        "codemirror_mode": {
            "name": "mlir"
        },
        "mimetype": "text/x-mlir",
        "file_extension": ".mlir",
        "pygments_lexer": "text"
    }

    @property
    def banner(self):
        """Returns kernel banner."""
        # Just a placeholder.
        return "mlir-opt kernel %s" % __version__

    def __init__(self, **kwargs):
        """Initializes the kernel with no previous result or executable."""
        super().__init__(**kwargs)
        # Decoded stdout of the last successful run; None when unset.
        self._ = None
        # Path of mlir-opt, resolved lazily by get_executable().
        self.executable = None
        # Mirrors the `silent` flag of the request being executed.
        self.silent = False

    def get_executable(self):
        """Returns the mlir-opt executable path.

        The path is resolved on first use and cached on the instance.

        Raises:
            OSError: If the executable cannot be found.
        """
        if not self.executable:
            self.executable = _get_executable()
        return self.executable

    def process_output(self, output):
        """Reports regular command output.

        Sends `output` as a `stdout` stream message unless the current
        request is silent.
        """
        if not self.silent:
            # Send standard output
            stream_content = {'name': 'stdout', 'text': output}
            self.send_response(self.iopub_socket, 'stream', stream_content)

    def process_error(self, output):
        """Reports error response.

        Sends `output` as a `stderr` stream message unless the current
        request is silent.
        """
        if not self.silent:
            # Send standard error
            stream_content = {'name': 'stderr', 'text': output}
            self.send_response(self.iopub_socket, 'stream', stream_content)

    def _run_mlir_opt(self, code):
        """Runs mlir-opt on `code` by piping it via the filesystem.

        If the cell ends with a line consisting of `_` (trailing whitespace
        is ignored), that `_` is replaced by the previous result.

        Args:
            code: Cell contents as a string.

        Returns:
            A tuple `(output, errors, exitcode)` where `output` and `errors`
            are the captured stdout/stderr bytes with the temporary input
            filename replaced by `<<input>>`.

        Raises:
            OSError: If the mlir-opt executable cannot be found or started.
            NameError: If `_` is used but no previous result is set.
        """
        executable = self.get_executable()

        # Simple handling of repeating last line. Trailing whitespace is
        # stripped first so that a final newline after `_` does not prevent
        # the substitution.
        stripped = code.rstrip()
        if stripped.endswith('\n_'):
            if not self._:
                raise NameError('No previous result set')
            code = stripped[:-1] + self._

        # Create the file outside the try block so that the cleanup below
        # only ever runs for a file that actually exists.
        inputmlir = tempfile.NamedTemporaryFile(delete=False)
        try:
            try:
                inputmlir.write(code.encode("utf-8"))
            finally:
                # Always close before running/unlinking: an open handle
                # leaks a descriptor and blocks deletion on some platforms.
                inputmlir.close()
            command = [
                # Specify input and output file to error out if also
                # set as arg.
                executable,
                '--color',
                inputmlir.name,
                '-o',
                '-'
            ]
            # subprocess.run kills and reaps the child if waiting is
            # interrupted (e.g. KeyboardInterrupt), so no stray mlir-opt
            # process is left behind.
            result = subprocess.run(command,
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE)
        finally:
            os.unlink(inputmlir.name)

        # Replace temporary filename with placeholder. This accepts the very
        # remote chance that the full input filename (generated above)
        # overlaps with something in the dump unrelated to the file.
        fname = inputmlir.name.encode("utf-8")
        output = result.stdout.replace(fname, b"<<input>>")
        errors = result.stderr.replace(fname, b"<<input>>")
        return output, errors, result.returncode

    def do_execute(self,
                   code,
                   silent,
                   store_history=True,
                   user_expressions=None,
                   allow_stdin=False):
        """Execute user code using mlir-opt binary.

        Args:
            code: Cell contents to feed to mlir-opt.
            silent: If true, no stream or result messages are sent for a
                successful run.
            store_history: Unused.
            user_expressions: Unused.
            allow_stdin: Unused.

        Returns:
            The reply content dict: status `ok` on success or for an empty
            cell, `abort` on KeyboardInterrupt, and `error` (with the exit
            code as `evalue`) when mlir-opt fails or an exception occurs.
        """

        def ok_status():
            """Returns OK status."""
            return {
                'status': 'ok',
                'execution_count': self.execution_count,
                'payload': [],
                'user_expressions': {}
            }

        self.silent = silent
        if not code.strip():
            return ok_status()

        try:
            output, errors, exitcode = self._run_mlir_opt(code)

            if exitcode:
                self._ = None
            else:
                self._ = output.decode("utf-8")
        except KeyboardInterrupt:
            return {'status': 'abort', 'execution_count': self.execution_count}
        except Exception as error:
            # Print traceback for local debugging.
            traceback.print_exc()
            self._ = None
            exitcode = 255
            errors = repr(error).encode("utf-8")

        # Diagnostics are only displayed, so undecodable bytes are replaced
        # rather than allowed to raise outside the try block above.
        error_text = errors.decode("utf-8", errors="replace")

        if exitcode:
            content = {'ename': '', 'evalue': str(exitcode), 'traceback': []}

            self.send_response(self.iopub_socket, 'error', content)
            self.process_error(error_text)

            content['execution_count'] = self.execution_count
            content['status'] = 'error'
            return content

        if not silent:
            data = {}
            data['text/x-mlir'] = self._
            content = {
                'execution_count': self.execution_count,
                'data': data,
                'metadata': {}
            }
            self.send_response(self.iopub_socket, 'execute_result', content)
            self.process_output(self._)
            self.process_error(error_text)
        return ok_status()
196