//===- SparseTensorPipelines.cpp - Pipelines for sparse tensor code -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8b85ed4e0Swren romano 
9b85ed4e0Swren romano #include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
10b85ed4e0Swren romano 
11b85ed4e0Swren romano #include "mlir/Conversion/Passes.h"
12*c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
13*c66303c2SMatthias Springer #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14b85ed4e0Swren romano #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1536550692SRiver Riddle #include "mlir/Dialect/Func/IR/FuncOps.h"
16b85ed4e0Swren romano #include "mlir/Dialect/Linalg/Passes.h"
17b85ed4e0Swren romano #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
18b85ed4e0Swren romano #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
19b85ed4e0Swren romano #include "mlir/Pass/PassManager.h"
20b85ed4e0Swren romano 
21b85ed4e0Swren romano using namespace mlir;
22b85ed4e0Swren romano using namespace mlir::sparse_tensor;
23b85ed4e0Swren romano 
24*c66303c2SMatthias Springer /// Return configuration options for One-Shot Bufferize.
25*c66303c2SMatthias Springer static bufferization::OneShotBufferizationOptions
getBufferizationOptions(bool analysisOnly)26*c66303c2SMatthias Springer getBufferizationOptions(bool analysisOnly) {
27*c66303c2SMatthias Springer   using namespace bufferization;
28*c66303c2SMatthias Springer   OneShotBufferizationOptions options;
29*c66303c2SMatthias Springer   options.bufferizeFunctionBoundaries = true;
30*c66303c2SMatthias Springer   // TODO(springerm): To spot memory leaks more easily, returning dense allocs
31*c66303c2SMatthias Springer   // should be disallowed.
32*c66303c2SMatthias Springer   options.allowReturnAllocs = true;
33*c66303c2SMatthias Springer   options.functionBoundaryTypeConversion =
34*c66303c2SMatthias Springer       BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
35*c66303c2SMatthias Springer   options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
36*c66303c2SMatthias Springer                                       const BufferizationOptions &options) {
37*c66303c2SMatthias Springer     return getMemRefTypeWithStaticIdentityLayout(
38*c66303c2SMatthias Springer         value.getType().cast<TensorType>(), memorySpace);
39*c66303c2SMatthias Springer   };
40*c66303c2SMatthias Springer   if (analysisOnly) {
41*c66303c2SMatthias Springer     options.testAnalysisOnly = true;
42*c66303c2SMatthias Springer     options.printConflicts = true;
43*c66303c2SMatthias Springer   }
44*c66303c2SMatthias Springer   return options;
45*c66303c2SMatthias Springer }
46*c66303c2SMatthias Springer 
//===----------------------------------------------------------------------===//
// Pipeline implementation.
//===----------------------------------------------------------------------===//
50b85ed4e0Swren romano 
buildSparseCompiler(OpPassManager & pm,const SparseCompilerOptions & options)51b85ed4e0Swren romano void mlir::sparse_tensor::buildSparseCompiler(
52b85ed4e0Swren romano     OpPassManager &pm, const SparseCompilerOptions &options) {
534998b1a6Swren romano   // TODO(wrengr): ensure the original `pm` is for ModuleOp
5458ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createLinalgGeneralizationPass());
55*c66303c2SMatthias Springer   pm.addPass(
56*c66303c2SMatthias Springer       bufferization::createTensorCopyInsertionPass(getBufferizationOptions(
57*c66303c2SMatthias Springer           /*analysisOnly=*/options.testBufferizationAnalysisOnly)));
58*c66303c2SMatthias Springer   if (options.testBufferizationAnalysisOnly)
59*c66303c2SMatthias Springer     return;
60b85ed4e0Swren romano   pm.addPass(createSparsificationPass(options.sparsificationOptions()));
61c7e24db4Swren romano   pm.addPass(createSparseTensorConversionPass(
62c7e24db4Swren romano       options.sparseTensorConversionOptions()));
63*c66303c2SMatthias Springer   pm.addPass(createDenseBufferizationPass(
64*c66303c2SMatthias Springer       getBufferizationOptions(/*analysisOnly=*/false)));
65*c66303c2SMatthias Springer   pm.addNestedPass<func::FuncOp>(
66*c66303c2SMatthias Springer       mlir::bufferization::createFinalizingBufferizePass());
67*c66303c2SMatthias Springer   // TODO(springerm): Add sparse support to the BufferDeallocation pass and add
68*c66303c2SMatthias Springer   // it to this pipeline.
6958ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
7058ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createConvertVectorToSCFPass());
7158ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createConvertSCFToCFPass());
72b85ed4e0Swren romano   pm.addPass(createLowerAffinePass());
734998b1a6Swren romano   pm.addPass(createConvertVectorToLLVMPass(options.lowerVectorToLLVMOptions()));
74b85ed4e0Swren romano   pm.addPass(createMemRefToLLVMPass());
75736c1b66SAart Bik   pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
7658ceae95SRiver Riddle   pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
775b122a73SAart Bik   pm.addPass(createConvertMathToLibmPass());
78a9e354c8SAart Bik   pm.addPass(createConvertComplexToLibmPass());
79a9e354c8SAart Bik   pm.addPass(createConvertComplexToLLVMPass());
805a7b9194SRiver Riddle   pm.addPass(createConvertFuncToLLVMPass());
81b85ed4e0Swren romano   pm.addPass(createReconcileUnrealizedCastsPass());
82b85ed4e0Swren romano }
83b85ed4e0Swren romano 
//===----------------------------------------------------------------------===//
// Pipeline registration.
//===----------------------------------------------------------------------===//
87b85ed4e0Swren romano 
registerSparseTensorPipelines()88b85ed4e0Swren romano void mlir::sparse_tensor::registerSparseTensorPipelines() {
89b85ed4e0Swren romano   PassPipelineRegistration<SparseCompilerOptions>(
90b85ed4e0Swren romano       "sparse-compiler",
91b85ed4e0Swren romano       "The standard pipeline for taking sparsity-agnostic IR using the"
92b85ed4e0Swren romano       " sparse-tensor type, and lowering it to LLVM IR with concrete"
93b85ed4e0Swren romano       " representations and algorithms for sparse tensors.",
94b85ed4e0Swren romano       buildSparseCompiler);
95b85ed4e0Swren romano }
96