1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Math/IR/Math.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 17 #include "llvm/Support/Debug.h" 18 19 #define DEBUG_TYPE "math-to-spirv-pattern" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Operation conversion 25 //===----------------------------------------------------------------------===// 26 27 // Note that DRR cannot be used for the patterns in this file: we may need to 28 // convert type along the way, which requires ConversionPattern. DRR generates 29 // normal RewritePattern. 30 31 namespace { 32 33 /// Converts unary and binary standard operations to SPIR-V operations. 34 template <typename StdOp, typename SPIRVOp> 35 class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> { 36 public: 37 using OpConversionPattern<StdOp>::OpConversionPattern; 38 39 LogicalResult 40 matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, 41 ConversionPatternRewriter &rewriter) const override { 42 assert(adaptor.getOperands().size() <= 2); 43 auto dstType = this->getTypeConverter()->convertType(operation.getType()); 44 if (!dstType) 45 return failure(); 46 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && 47 dstType != operation.getType()) { 48 return operation.emitError( 49 "bitwidth emulation is not implemented yet on unsigned op"); 50 } 51 rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, 52 adaptor.getOperands()); 53 return success(); 54 } 55 }; 56 57 /// Converts math.log1p to SPIR-V ops. 58 /// 59 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 60 /// these operations. 61 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 62 public: 63 using OpConversionPattern<math::Log1pOp>::OpConversionPattern; 64 65 LogicalResult 66 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 67 ConversionPatternRewriter &rewriter) const override { 68 assert(adaptor.getOperands().size() == 1); 69 Location loc = operation.getLoc(); 70 auto type = 71 this->getTypeConverter()->convertType(operation.operand().getType()); 72 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 73 auto onePlus = 74 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]); 75 rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus); 76 return success(); 77 } 78 }; 79 80 } // namespace 81 82 //===----------------------------------------------------------------------===// 83 // Pattern population 84 //===----------------------------------------------------------------------===// 85 86 namespace mlir { 87 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 88 RewritePatternSet &patterns) { 89 patterns.add<Log1pOpPattern, 90 UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>, 91 UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>, 92 UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>, 93 UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 94 UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>, 95 UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>, 96 UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 97 UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 98 typeConverter, patterns.getContext()); 99 } 100 101 } // namespace mlir 102