# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains the sparse compiler class.

# NOTE(review): `all_passes_registration` is not referenced below; it appears
# to be imported only for its import-time side effects (presumably registering
# the passes so the pipeline string can be parsed) -- confirm before removing.
from mlir import all_passes_registration
from mlir import execution_engine
from mlir import ir
from mlir import passmanager
from typing import Sequence


class SparseCompiler:
  """Sparse compiler class for compiling and building MLIR modules.

  Wraps a `sparse-compiler` pass pipeline together with the settings used
  to JIT the lowered module through an `ExecutionEngine`.
  """

  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
    """Creates a sparse compiler.

    Args:
      options: Extra option text spliced verbatim into the option list of the
        `sparse-compiler` pipeline, ahead of the fixed options
        `reassociate-fp-reductions=1` and `enable-index-optimizations=1`.
      opt_level: Optimization level forwarded to `ExecutionEngine`.
      shared_libs: Shared library paths forwarded to `ExecutionEngine`.
    """
    # `{{` / `}}` are escaped braces in the f-string; they produce the literal
    # braces that delimit the pipeline's option list.
    pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}'
    self.pipeline = pipeline
    self.opt_level = opt_level
    self.shared_libs = shared_libs

  def __call__(self, module: ir.Module):
    """Convenience application method; equivalent to `compile(module)`.

    Returns:
      None. The module is processed in place by `compile`.
    """
    self.compile(module)

  def compile(self, module: ir.Module):
    """Compiles the module by invoking the sparse compiler pipeline.

    Parses `self.pipeline` into a `PassManager` and runs it on `module`.

    Args:
      module: The MLIR module to run the pipeline on.
    """
    passmanager.PassManager.parse(self.pipeline).run(module)

  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
    """Wraps the module in a JIT execution engine.

    Args:
      module: The module to execute, typically one already processed by
        `compile`.

    Returns:
      An `ExecutionEngine` built with this compiler's `opt_level` and
      `shared_libs`.
    """
    return execution_engine.ExecutionEngine(
        module, opt_level=self.opt_level, shared_libs=self.shared_libs)

  def compile_and_jit(self,
                      module: ir.Module) -> execution_engine.ExecutionEngine:
    """Compiles the module and wraps it in a JIT execution engine.

    Args:
      module: The MLIR module to compile and execute.

    Returns:
      The `ExecutionEngine` returned by `jit` for the compiled module.
    """
    self.compile(module)
    return self.jit(module)