//===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Vector dialect to SPIRV dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"

#include "../PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include <numeric>

using namespace mlir;

/// Gets the first integer value from `attr`, assuming it is an integer array
/// attribute.
static uint64_t getFirstIntValue(ArrayAttr attr) { return (*attr.getAsValueRange().begin()).getZExtValue(); } namespace { struct VectorBitcastConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); if (!dstType) return failure(); if (dstType == adaptor.source().getType()) rewriter.replaceOp(bitcastOp, adaptor.source()); else rewriter.replaceOpWithNewOp(bitcastOp, dstType, adaptor.source()); return success(); } }; struct VectorBroadcastConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (broadcastOp.source().getType().isa() || !spirv::CompositeType::isValid(broadcastOp.getVectorType())) return failure(); SmallVector source(broadcastOp.getVectorType().getNumElements(), adaptor.source()); rewriter.replaceOpWithNewOp( broadcastOp, broadcastOp.getVectorType(), source); return success(); } }; struct VectorExtractOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only support extracting a scalar value now. 
VectorType resultVectorType = extractOp.getType().dyn_cast(); if (resultVectorType && resultVectorType.getNumElements() > 1) return failure(); auto dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) return failure(); if (adaptor.vector().getType().isa()) { rewriter.replaceOp(extractOp, adaptor.vector()); return success(); } int32_t id = getFirstIntValue(extractOp.position()); rewriter.replaceOpWithNewOp( extractOp, adaptor.vector(), id); return success(); } }; struct VectorExtractStridedSliceOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) return failure(); uint64_t offset = getFirstIntValue(extractOp.offsets()); uint64_t size = getFirstIntValue(extractOp.sizes()); uint64_t stride = getFirstIntValue(extractOp.strides()); if (stride != 1) return failure(); Value srcVector = adaptor.getOperands().front(); // Extract vector<1xT> case. 
if (dstType.isa()) { rewriter.replaceOpWithNewOp(extractOp, srcVector, offset); return success(); } SmallVector indices(size); std::iota(indices.begin(), indices.end(), offset); rewriter.replaceOpWithNewOp( extractOp, dstType, srcVector, srcVector, rewriter.getI32ArrayAttr(indices)); return success(); } }; struct VectorFmaOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) return failure(); rewriter.replaceOpWithNewOp( fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); return success(); } }; struct VectorInsertOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (insertOp.getSourceType().isa() || !spirv::CompositeType::isValid(insertOp.getDestVectorType())) return failure(); int32_t id = getFirstIntValue(insertOp.position()); rewriter.replaceOpWithNewOp( insertOp, adaptor.source(), adaptor.dest(), id); return success(); } }; struct VectorExtractElementOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) return failure(); rewriter.replaceOpWithNewOp( extractElementOp, extractElementOp.getType(), adaptor.vector(), extractElementOp.position()); return success(); } }; struct VectorInsertElementOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor, ConversionPatternRewriter 
&rewriter) const override { if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) return failure(); rewriter.replaceOpWithNewOp( insertElementOp, insertElementOp.getType(), insertElementOp.dest(), adaptor.source(), insertElementOp.position()); return success(); } }; struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value srcVector = adaptor.getOperands().front(); Value dstVector = adaptor.getOperands().back(); // Insert scalar values not supported yet. if (srcVector.getType().isa() || dstVector.getType().isa()) return failure(); uint64_t stride = getFirstIntValue(insertOp.strides()); if (stride != 1) return failure(); uint64_t totalSize = dstVector.getType().cast().getNumElements(); uint64_t insertSize = srcVector.getType().cast().getNumElements(); uint64_t offset = getFirstIntValue(insertOp.offsets()); SmallVector indices(totalSize); std::iota(indices.begin(), indices.end(), 0); std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, totalSize); rewriter.replaceOpWithNewOp( insertOp, dstVector.getType(), dstVector, srcVector, rewriter.getI32ArrayAttr(indices)); return success(); } }; } // namespace void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); }