13fef2d26SRiver Riddle //===- TestPolynomialApproximation.cpp - Test math ops approximations -----===//
23fef2d26SRiver Riddle //
33fef2d26SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
43fef2d26SRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
53fef2d26SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
63fef2d26SRiver Riddle //
73fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
83fef2d26SRiver Riddle //
93fef2d26SRiver Riddle // This file contains test passes for expanding math operations into
103fef2d26SRiver Riddle // polynomial approximations.
113fef2d26SRiver Riddle //
123fef2d26SRiver Riddle //===----------------------------------------------------------------------===//
133fef2d26SRiver Riddle 
14a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
153fef2d26SRiver Riddle #include "mlir/Dialect/Math/IR/Math.h"
163fef2d26SRiver Riddle #include "mlir/Dialect/Math/Transforms/Passes.h"
1799ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
1835553d45SEmilio Cota #include "mlir/Dialect/X86Vector/X86VectorDialect.h"
193fef2d26SRiver Riddle #include "mlir/Pass/Pass.h"
203fef2d26SRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
213fef2d26SRiver Riddle 
223fef2d26SRiver Riddle using namespace mlir;
233fef2d26SRiver Riddle 
243fef2d26SRiver Riddle namespace {
253fef2d26SRiver Riddle struct TestMathPolynomialApproximationPass
2687d6bf37SRiver Riddle     : public PassWrapper<TestMathPolynomialApproximationPass, OperationPass<>> {
27*5e50dd04SRiver Riddle   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
28*5e50dd04SRiver Riddle       TestMathPolynomialApproximationPass)
29*5e50dd04SRiver Riddle 
3035553d45SEmilio Cota   TestMathPolynomialApproximationPass() = default;
TestMathPolynomialApproximationPass__anond2ee01f50111::TestMathPolynomialApproximationPass3135553d45SEmilio Cota   TestMathPolynomialApproximationPass(
323bab9d4eSMehdi Amini       const TestMathPolynomialApproximationPass &pass)
333bab9d4eSMehdi Amini       : PassWrapper(pass) {}
3435553d45SEmilio Cota 
3541574554SRiver Riddle   void runOnOperation() override;
getDependentDialects__anond2ee01f50111::TestMathPolynomialApproximationPass363fef2d26SRiver Riddle   void getDependentDialects(DialectRegistry &registry) const override {
37a54f4eaeSMogball     registry.insert<arith::ArithmeticDialect, math::MathDialect,
38a54f4eaeSMogball                     vector::VectorDialect>();
3935553d45SEmilio Cota     if (enableAvx2)
4035553d45SEmilio Cota       registry.insert<x86vector::X86VectorDialect>();
413fef2d26SRiver Riddle   }
getArgument__anond2ee01f50111::TestMathPolynomialApproximationPass42b5e22e6dSMehdi Amini   StringRef getArgument() const final {
43b5e22e6dSMehdi Amini     return "test-math-polynomial-approximation";
44b5e22e6dSMehdi Amini   }
getDescription__anond2ee01f50111::TestMathPolynomialApproximationPass45b5e22e6dSMehdi Amini   StringRef getDescription() const final {
46b5e22e6dSMehdi Amini     return "Test math polynomial approximations";
47b5e22e6dSMehdi Amini   }
4835553d45SEmilio Cota 
4935553d45SEmilio Cota   Option<bool> enableAvx2{
5035553d45SEmilio Cota       *this, "enable-avx2",
5135553d45SEmilio Cota       llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
5235553d45SEmilio Cota                      "X86Vector dialect"),
5335553d45SEmilio Cota       llvm::cl::init(false)};
543fef2d26SRiver Riddle };
55be0a7e9fSMehdi Amini } // namespace
563fef2d26SRiver Riddle 
runOnOperation()5741574554SRiver Riddle void TestMathPolynomialApproximationPass::runOnOperation() {
583fef2d26SRiver Riddle   RewritePatternSet patterns(&getContext());
5902b6fb21SMehdi Amini   MathPolynomialApproximationOptions approxOptions;
6002b6fb21SMehdi Amini   approxOptions.enableAvx2 = enableAvx2;
6102b6fb21SMehdi Amini   populateMathPolynomialApproximationPatterns(patterns, approxOptions);
623fef2d26SRiver Riddle   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
633fef2d26SRiver Riddle }
643fef2d26SRiver Riddle 
653fef2d26SRiver Riddle namespace mlir {
663fef2d26SRiver Riddle namespace test {
registerTestMathPolynomialApproximationPass()673fef2d26SRiver Riddle void registerTestMathPolynomialApproximationPass() {
68b5e22e6dSMehdi Amini   PassRegistration<TestMathPolynomialApproximationPass>();
693fef2d26SRiver Riddle }
703fef2d26SRiver Riddle } // namespace test
713fef2d26SRiver Riddle } // namespace mlir
72