1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5#  This file contains the sparse compiler class.
6
7from mlir import all_passes_registration
8from mlir import execution_engine
9from mlir import ir
10from mlir import passmanager
11from typing import Sequence
12
class SparseCompiler:
  """Sparse compiler class for compiling and building MLIR modules.

  Bundles the `sparse-compiler` pass pipeline with the settings used to JIT
  the compiled module through an MLIR execution engine.
  """

  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
    """Configures the compilation pipeline and JIT settings.

    Args:
      options: Extra `sparse-compiler` pipeline options, inserted verbatim
        inside the pipeline's braces ahead of the fixed options below. May be
        the empty string.
      opt_level: Forwarded as `opt_level` to the execution engine built by
        `jit`.
      shared_libs: Forwarded as `shared_libs` to the execution engine built by
        `jit`.
    """
    # These options are always enabled on top of the caller-supplied ones.
    fixed_options = 'reassociate-fp-reductions=1 enable-index-optimizations=1'
    # Doubled braces emit literal `{` / `}` delimiting the pipeline options.
    self.pipeline = f'sparse-compiler{{{options} {fixed_options}}}'
    self.opt_level = opt_level
    self.shared_libs = shared_libs

  def __call__(self, module: ir.Module):
    """Convenience application method; equivalent to `self.compile(module)`."""
    self.compile(module)

  def compile(self, module: ir.Module):
    """Compiles the module by invoking the sparse compiler pipeline.

    The pipeline string built in `__init__` is parsed into a pass manager that
    is run directly on `module`; nothing is returned. Exceptions from parsing
    or running the pipeline are not caught here.
    """
    passmanager.PassManager.parse(self.pipeline).run(module)

  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
    """Wraps the module in a JIT execution engine.

    The engine is created with the `opt_level` and `shared_libs` given to the
    constructor. The module is presumably already compiled (see `compile`).
    """
    return execution_engine.ExecutionEngine(
        module, opt_level=self.opt_level, shared_libs=self.shared_libs)

  def compile_and_jit(self,
                      module: ir.Module) -> execution_engine.ExecutionEngine:
    """Compiles the module, then returns a JIT execution engine wrapping it."""
    self.compile(module)
    return self.jit(module)
40