18b83b8f1SAart Bik#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
28b83b8f1SAart Bik#  See https://llvm.org/LICENSE.txt for license information.
38b83b8f1SAart Bik#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
48b83b8f1SAart Bik
58b83b8f1SAart Bik#  This file contains the sparse compiler class.
68b83b8f1SAart Bik
7*28063a28SAart Bikfrom mlir import execution_engine
88b83b8f1SAart Bikfrom mlir import ir
98b83b8f1SAart Bikfrom mlir import passmanager
10*28063a28SAart Bikfrom typing import Sequence
118b83b8f1SAart Bik
class SparseCompiler:
  """Lowers MLIR modules with the sparse compiler and wraps them for JIT use.

  An instance remembers a textual `sparse-compiler{...}` pass pipeline plus
  the settings handed to the execution engine. A module can then be lowered
  in place (`compile`, or calling the instance) and turned into an executable
  engine (`jit`), or both in one step (`compile_and_jit`).
  """

  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
    """Stores the pass pipeline and the execution-engine settings.

    Args:
      options: Additional sparse-compiler options. They are spliced verbatim
        into the pipeline string, ahead of the options this class always sets.
      opt_level: Optimization level forwarded to the execution engine.
      shared_libs: Shared libraries forwarded to the execution engine.
    """
    # Caller options come first; FP-reduction reassociation and index
    # optimizations are always switched on.
    self.pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}'
    self.opt_level = opt_level
    self.shared_libs = shared_libs

  def __call__(self, module: ir.Module):
    """Shorthand for `self.compile(module)`; returns nothing."""
    self.compile(module)

  def compile(self, module: ir.Module):
    """Runs the sparse compiler pipeline over `module`.

    The pass manager operates on `module` directly, so the caller's module
    object is what gets transformed (`compile_and_jit` relies on this by
    jitting the same object afterwards).
    """
    manager = passmanager.PassManager.parse(self.pipeline)
    manager.run(module)

  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
    """Builds an execution engine for `module`.

    The engine uses the opt level and shared libraries supplied at
    construction time. This method does not run the pass pipeline itself;
    see `compile_and_jit` for the combined flow.
    """
    engine = execution_engine.ExecutionEngine(
        module, opt_level=self.opt_level, shared_libs=self.shared_libs)
    return engine

  def compile_and_jit(
      self, module: ir.Module) -> execution_engine.ExecutionEngine:
    """Runs the sparse pipeline on `module`, then returns a JIT engine for it."""
    self.compile(module)
    return self.jit(module)
39