1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Math/IR/Math.h" 14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 17 #include "llvm/Support/Debug.h" 18 19 #define DEBUG_TYPE "math-to-spirv-pattern" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Operation conversion 25 //===----------------------------------------------------------------------===// 26 27 // Note that DRR cannot be used for the patterns in this file: we may need to 28 // convert type along the way, which requires ConversionPattern. DRR generates 29 // normal RewritePattern. 30 31 namespace { 32 33 /// Converts unary and binary standard operations to SPIR-V operations. 34 template <typename StdOp, typename SPIRVOp> 35 class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> { 36 public: 37 using OpConversionPattern<StdOp>::OpConversionPattern; 38 39 LogicalResult 40 matchAndRewrite(StdOp operation, ArrayRef<Value> operands, 41 ConversionPatternRewriter &rewriter) const override { 42 assert(operands.size() <= 2); 43 auto dstType = this->getTypeConverter()->convertType(operation.getType()); 44 if (!dstType) 45 return failure(); 46 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && 47 dstType != operation.getType()) { 48 return operation.emitError( 49 "bitwidth emulation is not implemented yet on unsigned op"); 50 } 51 rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands); 52 return success(); 53 } 54 }; 55 56 /// Converts math.log1p to SPIR-V ops. 57 /// 58 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 59 /// these operations. 60 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 61 public: 62 using OpConversionPattern<math::Log1pOp>::OpConversionPattern; 63 64 LogicalResult 65 matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands, 66 ConversionPatternRewriter &rewriter) const override { 67 assert(operands.size() == 1); 68 Location loc = operation.getLoc(); 69 auto type = 70 this->getTypeConverter()->convertType(operation.operand().getType()); 71 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 72 auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]); 73 rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus); 74 return success(); 75 } 76 }; 77 78 } // namespace 79 80 //===----------------------------------------------------------------------===// 81 // Pattern population 82 //===----------------------------------------------------------------------===// 83 84 namespace mlir { 85 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 86 RewritePatternSet &patterns) { 87 patterns.add<Log1pOpPattern, 88 UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>, 89 UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>, 90 UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>, 91 UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 92 UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>, 93 UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>, 94 UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 95 UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 96 typeConverter, patterns.getContext()); 97 } 98 99 } // namespace mlir 100