# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains the sparse compiler class.

from mlir import execution_engine
from mlir import ir
from mlir import passmanager
from typing import Sequence


class SparseCompiler:
  """Sparse compiler class for compiling and building MLIR modules.

  Holds a `sparse-compiler` pass-pipeline string together with the settings
  used to JIT a module, so a module can be compiled and executed in one place.
  """

  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
    """Records the pass pipeline and the JIT settings.

    Args:
      options: Option text spliced verbatim into the `sparse-compiler` option
        list, placed before the options this class always enables.
      opt_level: Optimization level forwarded to the execution engine.
      shared_libs: Shared-library paths forwarded to the execution engine.
    """
    # Options appended after the caller-supplied ones for every instance.
    fixed_options = 'reassociate-fp-reductions=1 enable-index-optimizations=1'
    self.pipeline = f'sparse-compiler{{{options} {fixed_options}}}'
    self.opt_level = opt_level
    self.shared_libs = shared_libs

  def __call__(self, module: ir.Module):
    """Applies the compiler pipeline to `module`; shorthand for `compile`."""
    self.compile(module)

  def compile(self, module: ir.Module):
    """Runs the sparse compiler pass pipeline over `module`."""
    pass_manager = passmanager.PassManager.parse(self.pipeline)
    pass_manager.run(module)

  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
    """Creates a JIT execution engine wrapping `module`.

    Uses the optimization level and shared libraries given at construction.
    Typically called after `compile` (as `compile_and_jit` does).
    """
    engine = execution_engine.ExecutionEngine(
        module, opt_level=self.opt_level, shared_libs=self.shared_libs)
    return engine

  def compile_and_jit(
      self, module: ir.Module) -> execution_engine.ExecutionEngine:
    """Compiles `module`, then returns a JIT execution engine for it."""
    self.compile(module)
    return self.jit(module)