//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Math dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "math-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert type along the way, which requires ConversionPattern. DRR generates // normal RewritePattern. namespace { /// Converts math.expm1 to SPIR-V ops. /// /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to /// these operations. template class ExpM1OpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); Location loc = operation.getLoc(); auto type = this->getTypeConverter()->convertType(operation.getType()); auto exp = rewriter.create(loc, type, adaptor.getOperand()); auto one = spirv::ConstantOp::getOne(type, loc, rewriter); rewriter.replaceOpWithNewOp(operation, exp, one); return success(); } }; /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to /// these operations. template class Log1pOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); Location loc = operation.getLoc(); auto type = this->getTypeConverter()->convertType(operation.getType()); auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); auto onePlus = rewriter.create(loc, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // GLSL patterns patterns .add, ExpM1OpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern>( typeConverter, patterns.getContext()); // OpenCL patterns patterns.add, ExpM1OpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern>( typeConverter, patterns.getContext()); } } // namespace mlir