1#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2#  See https://llvm.org/LICENSE.txt for license information.
3#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5#  This file contains the sparse compiler class.
6
7from mlir import all_passes_registration
8from mlir import ir
9from mlir import passmanager
10
class SparseCompiler:
  """Callable object that runs the `sparse-compiler` pass pipeline.

  The textual pipeline is assembled once, at construction time, from the
  caller-supplied options followed by two fixed options. It is stored in
  `self.pipeline` and parsed into a pass manager each time the object is
  called.
  """

  def __init__(self, options: str):
    """Build and store the pipeline string for the given `options`."""
    # These options are always appended after the caller's options,
    # separated from them by a single space.
    fixed_options = 'reassociate-fp-reductions=1 enable-index-optimizations=1'
    # Doubled braces produce the literal `{` and `}` around the option list.
    self.pipeline = 'sparse-compiler{{{} {}}}'.format(options, fixed_options)

  def __call__(self, module: ir.Module):
    """Parse the stored pipeline and run it on `module`."""
    pass_manager = passmanager.PassManager.parse(self.pipeline)
    pass_manager.run(module)
20