# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import shutil
import subprocess
import tempfile
import traceback
from subprocess import Popen

from ipykernel.kernelbase import Kernel

__version__ = '0.0.1'


def _get_executable():
    """Find the mlir-opt executable.

    The program is taken from the ``MLIR_OPT_EXECUTABLE`` environment
    variable and defaults to ``mlir-opt``. A value containing a directory
    component is checked directly; a bare name is searched for on ``PATH``.

    Returns:
        Path of an existing, executable file.

    Raises:
        OSError: if no matching executable is found.
    """
    program = os.environ.get('MLIR_OPT_EXECUTABLE', 'mlir-opt')
    # shutil.which handles both the explicit-path and the PATH-search case
    # and does not raise KeyError when PATH is absent from the environment.
    executable = shutil.which(program)
    if executable is None:
        raise OSError('mlir-opt not found, please see README')
    return executable


class MlirOptKernel(Kernel):
    """Kernel using mlir-opt inside jupyter.

    The reproducer syntax (`// configuration:`) is used to run passes. The
    previous result can be referenced by using `_` (this variable is reset
    upon error). E.g.,

    ```mlir
    // configuration: --pass
    func.func @foo(%tensor: tensor<2x3xf64>) -> tensor<3x2xf64> { ... }
    ```

    ```mlir
    // configuration: --next-pass
    _
    ```
    """

    implementation = 'mlir'
    implementation_version = __version__

    language_version = __version__
    language = "mlir"
    language_info = {
        "name": "mlir",
        "codemirror_mode": {
            "name": "mlir"
        },
        "mimetype": "text/x-mlir",
        "file_extension": ".mlir",
        "pygments_lexer": "text"
    }

    @property
    def banner(self):
        """Returns kernel banner."""
        # Just a placeholder.
        return "mlir-opt kernel %s" % __version__

    def __init__(self, **kwargs):
        """Initializes the kernel; keyword arguments go to `Kernel`."""
        Kernel.__init__(self, **kwargs)
        # Output of the last successful execution (str), or None.
        self._ = None
        # Lazily resolved path of the mlir-opt binary.
        self.executable = None
        # Mirrors the `silent` flag of the execution in progress.
        self.silent = False

    def get_executable(self):
        """Returns the mlir-opt executable path.

        The lookup is performed on first use and cached afterwards.

        Raises:
            OSError: if mlir-opt cannot be found.
        """
        if not self.executable:
            self.executable = _get_executable()
        return self.executable

    def _send_stream(self, name, text):
        """Sends `text` on the `name` stream unless running silently."""
        if not self.silent:
            stream_content = {'name': name, 'text': text}
            self.send_response(self.iopub_socket, 'stream', stream_content)

    def process_output(self, output):
        """Reports regular command output on the stdout stream."""
        self._send_stream('stdout', output)

    def process_error(self, output):
        """Reports error response on the stderr stream."""
        self._send_stream('stderr', output)

    def do_execute(self,
                   code,
                   silent,
                   store_history=True,
                   user_expressions=None,
                   allow_stdin=False):
        """Execute user code using mlir-opt binary.

        Args:
            code: Cell contents. If the text ends with a line consisting of
                `_`, that `_` is replaced by the previous successful result.
            silent: When true, no stream or result messages are sent.
            store_history: Unused.
            user_expressions: Unused.
            allow_stdin: Unused.

        Returns:
            A reply dict whose 'status' is 'ok', 'error' or 'abort'.
        """

        def ok_status():
            """Returns OK status."""
            return {
                'status': 'ok',
                'execution_count': self.execution_count,
                'payload': [],
                'user_expressions': {}
            }

        def run(code):
            """Run the code by piping via filesystem.

            Returns:
                Tuple `(stdout bytes, stderr bytes, exit code)`.

            Raises:
                NameError: if `_` is used without a previous result.
                OSError: if mlir-opt cannot be found or started.
            """
            # Simple handling of repeating last line.
            if code.endswith('\n_'):
                if not self._:
                    raise NameError('No previous result set')
                code = code[:-1] + self._

            # Resolve the binary before creating any temporary file.
            executable = self.get_executable()

            # Created outside the `try` so the cleanup below never refers to
            # an unbound name if creation itself fails.
            inputmlir = tempfile.NamedTemporaryFile(delete=False)
            try:
                # Close (and thereby flush) the file before mlir-opt reads
                # it; the `with` also closes the handle if the write fails.
                with inputmlir:
                    inputmlir.write(code.encode("utf-8"))
                command = [
                    # Specify input and output file to error out if also
                    # set as arg.
                    executable,
                    '--color',
                    inputmlir.name,
                    '-o',
                    '-'
                ]
                with Popen(command,
                           stdout=subprocess.PIPE,
                           stderr=subprocess.PIPE) as pipe:
                    try:
                        output, errors = pipe.communicate()
                    except BaseException:
                        # Do not leave mlir-opt running when interrupted
                        # (e.g. KeyboardInterrupt from the frontend).
                        pipe.kill()
                        raise
                    exitcode = pipe.returncode
            finally:
                os.unlink(inputmlir.name)

            # Replace temporary filename with placeholder. This accepts the
            # very remote chance that the full input filename (generated
            # above) overlaps with something in the dump unrelated to the
            # file.
            fname = inputmlir.name.encode("utf-8")
            output = output.replace(fname, b"<<input>>")
            errors = errors.replace(fname, b"<<input>>")
            return output, errors, exitcode

        self.silent = silent
        if not code.strip():
            return ok_status()

        try:
            output, errors, exitcode = run(code)

            if exitcode:
                self._ = None
            else:
                self._ = output.decode("utf-8")
        except KeyboardInterrupt:
            return {'status': 'abort', 'execution_count': self.execution_count}
        except Exception as error:
            # Print traceback for local debugging.
            traceback.print_exc()
            self._ = None
            exitcode = 255
            errors = repr(error).encode("utf-8")

        # Diagnostics are only displayed, so tolerate invalid UTF-8 rather
        # than raising outside of the handler above.
        error_text = errors.decode("utf-8", errors="replace")

        if exitcode:
            content = {'ename': '', 'evalue': str(exitcode), 'traceback': []}

            self.send_response(self.iopub_socket, 'error', content)
            self.process_error(error_text)

            content['execution_count'] = self.execution_count
            content['status'] = 'error'
            return content

        if not silent:
            data = {}
            data['text/x-mlir'] = self._
            content = {
                'execution_count': self.execution_count,
                'data': data,
                'metadata': {}
            }
            self.send_response(self.iopub_socket, 'execute_result', content)
            self.process_output(self._)
            self.process_error(error_text)
        return ok_status()