# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# This file contains the sparse compiler class.

from mlir import execution_engine
from mlir import ir
from mlir import passmanager
from typing import Sequence


class SparseCompiler:
  """Sparse compiler class for compiling and building MLIR modules."""

  def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
    """Configures the pass pipeline and JIT settings.

    Args:
      options: Extra option text spliced verbatim into the
        `sparse-compiler{...}` pipeline specification, ahead of the fixed
        `reassociate-fp-reductions=1 enable-index-optimizations=1` options.
      opt_level: Optimization level forwarded to the JIT execution engine.
      shared_libs: Shared library paths forwarded to the JIT execution engine.
    """
    # Textual pipeline spec later handed to `PassManager.parse`; the doubled
    # braces are f-string escapes for literal `{` and `}`.
    self.pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}'
    self.opt_level = opt_level
    self.shared_libs = shared_libs

  def __call__(self, module: ir.Module) -> None:
    """Convenience application method; equivalent to `compile(module)`."""
    self.compile(module)

  def compile(self, module: ir.Module) -> None:
    """Compiles the module by invoking the sparse compiler pipeline.

    The pass manager runs on `module` in place; nothing is returned.
    """
    passmanager.PassManager.parse(self.pipeline).run(module)

  def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
    """Wraps the module in a JIT execution engine.

    The module is not compiled here; call `compile` first (or use
    `compile_and_jit`) if the sparse pipeline has not been run yet.
    """
    return execution_engine.ExecutionEngine(
        module, opt_level=self.opt_level, shared_libs=self.shared_libs)

  def compile_and_jit(self,
                      module: ir.Module) -> execution_engine.ExecutionEngine:
    """Compiles the module in place, then wraps it in a JIT execution engine."""
    self.compile(module)
    return self.jit(module)