1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/VectorOps.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include <numeric> 23 24 using namespace mlir; 25 26 /// Gets the first integer value from `attr`, assuming it is an integer array 27 /// attribute. 28 static uint64_t getFirstIntValue(ArrayAttr attr) { 29 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 30 } 31 32 namespace { 33 34 struct VectorBitcastConvert final 35 : public OpConversionPattern<vector::BitCastOp> { 36 using OpConversionPattern::OpConversionPattern; 37 38 LogicalResult 39 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, 40 ConversionPatternRewriter &rewriter) const override { 41 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 42 if (!dstType) 43 return failure(); 44 45 if (dstType == adaptor.source().getType()) 46 rewriter.replaceOp(bitcastOp, adaptor.source()); 47 else 48 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 49 adaptor.source()); 50 51 return success(); 52 } 53 }; 54 55 struct VectorBroadcastConvert final 56 : public OpConversionPattern<vector::BroadcastOp> { 57 using OpConversionPattern::OpConversionPattern; 58 59 LogicalResult 60 matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, 61 ConversionPatternRewriter &rewriter) const override { 62 if (broadcastOp.source().getType().isa<VectorType>() || 63 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 64 return failure(); 65 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 66 adaptor.source()); 67 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 68 broadcastOp, broadcastOp.getVectorType(), source); 69 return success(); 70 } 71 }; 72 73 struct VectorExtractOpConvert final 74 : public OpConversionPattern<vector::ExtractOp> { 75 using OpConversionPattern::OpConversionPattern; 76 77 LogicalResult 78 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 79 ConversionPatternRewriter &rewriter) const override { 80 // Only support extracting a scalar value now. 81 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>(); 82 if (resultVectorType && resultVectorType.getNumElements() > 1) 83 return failure(); 84 85 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 86 if (!dstType) 87 return failure(); 88 89 if (adaptor.vector().getType().isa<spirv::ScalarType>()) { 90 rewriter.replaceOp(extractOp, adaptor.vector()); 91 return success(); 92 } 93 94 int32_t id = getFirstIntValue(extractOp.position()); 95 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 96 extractOp, adaptor.vector(), id); 97 return success(); 98 } 99 }; 100 101 struct VectorExtractStridedSliceOpConvert final 102 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 103 using OpConversionPattern::OpConversionPattern; 104 105 LogicalResult 106 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, 107 ConversionPatternRewriter &rewriter) const override { 108 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 109 if (!dstType) 110 return failure(); 111 112 113 uint64_t offset = getFirstIntValue(extractOp.offsets()); 114 uint64_t size = getFirstIntValue(extractOp.sizes()); 115 uint64_t stride = getFirstIntValue(extractOp.strides()); 116 if (stride != 1) 117 return failure(); 118 119 Value srcVector = adaptor.getOperands().front(); 120 121 // Extract vector<1xT> case. 122 if (dstType.isa<spirv::ScalarType>()) { 123 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, 124 srcVector, offset); 125 return success(); 126 } 127 128 SmallVector<int32_t, 2> indices(size); 129 std::iota(indices.begin(), indices.end(), offset); 130 131 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 132 extractOp, dstType, srcVector, srcVector, 133 rewriter.getI32ArrayAttr(indices)); 134 135 return success(); 136 } 137 }; 138 139 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 140 using OpConversionPattern::OpConversionPattern; 141 142 LogicalResult 143 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 144 ConversionPatternRewriter &rewriter) const override { 145 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 146 return failure(); 147 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 148 fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); 149 return success(); 150 } 151 }; 152 153 struct VectorInsertOpConvert final 154 : public OpConversionPattern<vector::InsertOp> { 155 using OpConversionPattern::OpConversionPattern; 156 157 LogicalResult 158 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 159 ConversionPatternRewriter &rewriter) const override { 160 if (insertOp.getSourceType().isa<VectorType>() || 161 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 162 return failure(); 163 int32_t id = getFirstIntValue(insertOp.position()); 164 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 165 insertOp, adaptor.source(), adaptor.dest(), id); 166 return success(); 167 } 168 }; 169 170 struct VectorExtractElementOpConvert final 171 : public OpConversionPattern<vector::ExtractElementOp> { 172 using OpConversionPattern::OpConversionPattern; 173 174 LogicalResult 175 matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor, 176 ConversionPatternRewriter &rewriter) const override { 177 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 178 return failure(); 179 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 180 extractElementOp, extractElementOp.getType(), adaptor.vector(), 181 extractElementOp.position()); 182 return success(); 183 } 184 }; 185 186 struct VectorInsertElementOpConvert final 187 : public OpConversionPattern<vector::InsertElementOp> { 188 using OpConversionPattern::OpConversionPattern; 189 190 LogicalResult 191 matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor, 192 ConversionPatternRewriter &rewriter) const override { 193 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 194 return failure(); 195 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 196 insertElementOp, insertElementOp.getType(), insertElementOp.dest(), 197 adaptor.source(), insertElementOp.position()); 198 return success(); 199 } 200 }; 201 202 struct VectorInsertStridedSliceOpConvert final 203 : public OpConversionPattern<vector::InsertStridedSliceOp> { 204 using OpConversionPattern::OpConversionPattern; 205 206 LogicalResult 207 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, 208 ConversionPatternRewriter &rewriter) const override { 209 Value srcVector = adaptor.getOperands().front(); 210 Value dstVector = adaptor.getOperands().back(); 211 212 // Insert scalar values not supported yet. 213 if (srcVector.getType().isa<spirv::ScalarType>() || 214 dstVector.getType().isa<spirv::ScalarType>()) 215 return failure(); 216 217 uint64_t stride = getFirstIntValue(insertOp.strides()); 218 if (stride != 1) 219 return failure(); 220 221 uint64_t totalSize = 222 dstVector.getType().cast<VectorType>().getNumElements(); 223 uint64_t insertSize = 224 srcVector.getType().cast<VectorType>().getNumElements(); 225 uint64_t offset = getFirstIntValue(insertOp.offsets()); 226 227 SmallVector<int32_t, 2> indices(totalSize); 228 std::iota(indices.begin(), indices.end(), 0); 229 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 230 totalSize); 231 232 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 233 insertOp, dstVector.getType(), dstVector, srcVector, 234 rewriter.getI32ArrayAttr(indices)); 235 236 return success(); 237 } 238 }; 239 240 } // namespace 241 242 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 243 RewritePatternSet &patterns) { 244 patterns.add<VectorBitcastConvert, VectorBroadcastConvert, 245 VectorExtractElementOpConvert, VectorExtractOpConvert, 246 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, 247 VectorInsertElementOpConvert, VectorInsertOpConvert, 248 VectorInsertStridedSliceOpConvert>(typeConverter, 249 patterns.getContext()); 250 } 251