//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Math dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "../SPIRVCommon/Pattern.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "math-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert type along the way, which requires ConversionPattern. DRR generates // normal RewritePattern. namespace { /// Converts math.copysign to SPIR-V ops. class CopySignPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = getTypeConverter()->convertType(copySignOp.getType()); if (!type) return failure(); FloatType floatType; if (auto scalarType = copySignOp.getType().dyn_cast()) { floatType = scalarType; } else if (auto vectorType = copySignOp.getType().dyn_cast()) { floatType = vectorType.getElementType().cast(); } else { return failure(); } Location loc = copySignOp.getLoc(); int bitwidth = floatType.getWidth(); Type intType = rewriter.getIntegerType(bitwidth); Value signMask = rewriter.create( loc, intType, rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1)))); Value valueMask = rewriter.create( loc, intType, rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1)) - 1u)); if (auto vectorType = copySignOp.getType().dyn_cast()) { assert(vectorType.getRank() == 1); int count = vectorType.getNumElements(); intType = VectorType::get(count, intType); SmallVector signSplat(count, signMask); signMask = rewriter.create(loc, intType, signSplat); SmallVector valueSplat(count, valueMask); valueMask = rewriter.create(loc, intType, valueSplat); } Value lhsCast = rewriter.create(loc, intType, adaptor.getLhs()); Value rhsCast = rewriter.create(loc, intType, adaptor.getRhs()); Value value = rewriter.create( loc, intType, ValueRange{lhsCast, valueMask}); Value sign = rewriter.create( loc, intType, ValueRange{rhsCast, signMask}); Value result = rewriter.create(loc, intType, ValueRange{value, sign}); rewriter.replaceOpWithNewOp(copySignOp, type, result); return success(); } }; /// Converts math.expm1 to SPIR-V ops. /// /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to /// these operations. template struct ExpM1OpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); Location loc = operation.getLoc(); auto type = this->getTypeConverter()->convertType(operation.getType()); auto exp = rewriter.create(loc, type, adaptor.getOperand()); auto one = spirv::ConstantOp::getOne(type, loc, rewriter); rewriter.replaceOpWithNewOp(operation, exp, one); return success(); } }; /// Converts math.log1p to SPIR-V ops. /// /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to /// these operations. template struct Log1pOpPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { assert(adaptor.getOperands().size() == 1); Location loc = operation.getLoc(); auto type = this->getTypeConverter()->convertType(operation.getType()); auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); auto onePlus = rewriter.create(loc, one, adaptor.getOperand()); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// namespace mlir { void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { // Core patterns patterns.add(typeConverter, patterns.getContext()); // GLSL patterns patterns .add, ExpM1OpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern>( typeConverter, patterns.getContext()); // OpenCL patterns patterns.add, ExpM1OpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern>( typeConverter, patterns.getContext()); } } // namespace mlir