1 //===- VectorToSPIRV.cpp - Vector to SPIR-V Patterns ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Vector dialect to SPIRV dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" 14 15 #include "../PassDetail.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 #include "mlir/Dialect/Vector/IR/VectorOps.h" 21 #include "mlir/Transforms/DialectConversion.h" 22 #include <numeric> 23 24 using namespace mlir; 25 26 /// Gets the first integer value from `attr`, assuming it is an integer array 27 /// attribute. 28 static uint64_t getFirstIntValue(ArrayAttr attr) { 29 return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue(); 30 } 31 32 namespace { 33 34 struct VectorBitcastConvert final 35 : public OpConversionPattern<vector::BitCastOp> { 36 using OpConversionPattern::OpConversionPattern; 37 38 LogicalResult 39 matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, 40 ConversionPatternRewriter &rewriter) const override { 41 auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); 42 if (!dstType) 43 return failure(); 44 45 if (dstType == adaptor.source().getType()) 46 rewriter.replaceOp(bitcastOp, adaptor.source()); 47 else 48 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(bitcastOp, dstType, 49 adaptor.source()); 50 51 return success(); 52 } 53 }; 54 55 struct VectorBroadcastConvert final 56 : public OpConversionPattern<vector::BroadcastOp> { 57 using OpConversionPattern::OpConversionPattern; 58 59 LogicalResult 60 matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, 61 ConversionPatternRewriter &rewriter) const override { 62 if (broadcastOp.source().getType().isa<VectorType>() || 63 !spirv::CompositeType::isValid(broadcastOp.getVectorType())) 64 return failure(); 65 SmallVector<Value, 4> source(broadcastOp.getVectorType().getNumElements(), 66 adaptor.source()); 67 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( 68 broadcastOp, broadcastOp.getVectorType(), source); 69 return success(); 70 } 71 }; 72 73 struct VectorExtractOpConvert final 74 : public OpConversionPattern<vector::ExtractOp> { 75 using OpConversionPattern::OpConversionPattern; 76 77 LogicalResult 78 matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, 79 ConversionPatternRewriter &rewriter) const override { 80 // Only support extracting a scalar value now. 81 VectorType resultVectorType = extractOp.getType().dyn_cast<VectorType>(); 82 if (resultVectorType && resultVectorType.getNumElements() > 1) 83 return failure(); 84 85 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 86 if (!dstType) 87 return failure(); 88 89 if (adaptor.vector().getType().isa<spirv::ScalarType>()) { 90 rewriter.replaceOp(extractOp, adaptor.vector()); 91 return success(); 92 } 93 94 int32_t id = getFirstIntValue(extractOp.position()); 95 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( 96 extractOp, adaptor.vector(), id); 97 return success(); 98 } 99 }; 100 101 struct VectorExtractStridedSliceOpConvert final 102 : public OpConversionPattern<vector::ExtractStridedSliceOp> { 103 using OpConversionPattern::OpConversionPattern; 104 105 LogicalResult 106 matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, 107 ConversionPatternRewriter &rewriter) const override { 108 auto dstType = getTypeConverter()->convertType(extractOp.getType()); 109 if (!dstType) 110 return failure(); 111 112 113 uint64_t offset = getFirstIntValue(extractOp.offsets()); 114 uint64_t size = getFirstIntValue(extractOp.sizes()); 115 uint64_t stride = getFirstIntValue(extractOp.strides()); 116 if (stride != 1) 117 return failure(); 118 119 Value srcVector = adaptor.getOperands().front(); 120 121 // Extract vector<1xT> case. 122 if (dstType.isa<spirv::ScalarType>()) { 123 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(extractOp, 124 srcVector, offset); 125 return success(); 126 } 127 128 SmallVector<int32_t, 2> indices(size); 129 std::iota(indices.begin(), indices.end(), offset); 130 131 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 132 extractOp, dstType, srcVector, srcVector, 133 rewriter.getI32ArrayAttr(indices)); 134 135 return success(); 136 } 137 }; 138 139 struct VectorFmaOpConvert final : public OpConversionPattern<vector::FMAOp> { 140 using OpConversionPattern::OpConversionPattern; 141 142 LogicalResult 143 matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, 144 ConversionPatternRewriter &rewriter) const override { 145 if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) 146 return failure(); 147 rewriter.replaceOpWithNewOp<spirv::GLSLFmaOp>( 148 fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); 149 return success(); 150 } 151 }; 152 153 struct VectorInsertOpConvert final 154 : public OpConversionPattern<vector::InsertOp> { 155 using OpConversionPattern::OpConversionPattern; 156 157 LogicalResult 158 matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, 159 ConversionPatternRewriter &rewriter) const override { 160 // Special case for inserting scalar values into size-1 vectors. 161 if (insertOp.getSourceType().isIntOrFloat() && 162 insertOp.getDestVectorType().getNumElements() == 1) { 163 rewriter.replaceOp(insertOp, adaptor.source()); 164 return success(); 165 } 166 167 if (insertOp.getSourceType().isa<VectorType>() || 168 !spirv::CompositeType::isValid(insertOp.getDestVectorType())) 169 return failure(); 170 int32_t id = getFirstIntValue(insertOp.position()); 171 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 172 insertOp, adaptor.source(), adaptor.dest(), id); 173 return success(); 174 } 175 }; 176 177 struct VectorExtractElementOpConvert final 178 : public OpConversionPattern<vector::ExtractElementOp> { 179 using OpConversionPattern::OpConversionPattern; 180 181 LogicalResult 182 matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor, 183 ConversionPatternRewriter &rewriter) const override { 184 if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) 185 return failure(); 186 rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( 187 extractElementOp, extractElementOp.getType(), adaptor.vector(), 188 extractElementOp.position()); 189 return success(); 190 } 191 }; 192 193 struct VectorInsertElementOpConvert final 194 : public OpConversionPattern<vector::InsertElementOp> { 195 using OpConversionPattern::OpConversionPattern; 196 197 LogicalResult 198 matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor, 199 ConversionPatternRewriter &rewriter) const override { 200 if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) 201 return failure(); 202 rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( 203 insertElementOp, insertElementOp.getType(), insertElementOp.dest(), 204 adaptor.source(), insertElementOp.position()); 205 return success(); 206 } 207 }; 208 209 struct VectorInsertStridedSliceOpConvert final 210 : public OpConversionPattern<vector::InsertStridedSliceOp> { 211 using OpConversionPattern::OpConversionPattern; 212 213 LogicalResult 214 matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, 215 ConversionPatternRewriter &rewriter) const override { 216 Value srcVector = adaptor.getOperands().front(); 217 Value dstVector = adaptor.getOperands().back(); 218 219 uint64_t stride = getFirstIntValue(insertOp.strides()); 220 if (stride != 1) 221 return failure(); 222 uint64_t offset = getFirstIntValue(insertOp.offsets()); 223 224 if (srcVector.getType().isa<spirv::ScalarType>()) { 225 assert(!dstVector.getType().isa<spirv::ScalarType>()); 226 rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( 227 insertOp, dstVector.getType(), srcVector, dstVector, 228 rewriter.getI32ArrayAttr(offset)); 229 return success(); 230 } 231 232 uint64_t totalSize = 233 dstVector.getType().cast<VectorType>().getNumElements(); 234 uint64_t insertSize = 235 srcVector.getType().cast<VectorType>().getNumElements(); 236 237 SmallVector<int32_t, 2> indices(totalSize); 238 std::iota(indices.begin(), indices.end(), 0); 239 std::iota(indices.begin() + offset, indices.begin() + offset + insertSize, 240 totalSize); 241 242 rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>( 243 insertOp, dstVector.getType(), dstVector, srcVector, 244 rewriter.getI32ArrayAttr(indices)); 245 246 return success(); 247 } 248 }; 249 250 class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> { 251 public: 252 using OpConversionPattern<vector::SplatOp>::OpConversionPattern; 253 254 LogicalResult 255 matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor, 256 ConversionPatternRewriter &rewriter) const override { 257 VectorType dstVecType = op.getType(); 258 if (!spirv::CompositeType::isValid(dstVecType)) 259 return failure(); 260 SmallVector<Value, 4> source(dstVecType.getNumElements(), adaptor.input()); 261 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType, 262 source); 263 return success(); 264 } 265 }; 266 267 } // namespace 268 269 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 270 RewritePatternSet &patterns) { 271 patterns.add<VectorBitcastConvert, VectorBroadcastConvert, 272 VectorExtractElementOpConvert, VectorExtractOpConvert, 273 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, 274 VectorInsertElementOpConvert, VectorInsertOpConvert, 275 VectorInsertStridedSliceOpConvert, VectorSplatPattern>( 276 typeConverter, patterns.getContext()); 277 } 278